# 🧵 DeepSeek-Coder-V2: YARN Context Extension Deep Dive

## 🎯 Learning Objectives

Master the **YARN (Yet Another RoPE extensioN)** technique that DeepSeek-Coder-V2 uses to extend its context length from 16K to 128K tokens:

1. **RoPE Fundamentals**: Understand Rotary Position Embedding
2. **YARN Theory**: The mechanism for extending context length
3. **Implementation Details**: Code from basic to advanced
4. **Performance Analysis**: "Needle in a Haystack" evaluation
5. **Long Context Applications**: Applications in code understanding

## 📚 Paper References

**Section 3.4: Long Context Extension**
> "Following DeepSeek-V2, we extend the context length of DeepSeek-Coder-V2 to 128K using Yarn (Peng et al., 2023). The hyper-parameters of YARN are the same as DeepSeek-V2: the scale s to 40, α to 1, β to 32."

**Key Statistics:**
- **Original context**: 16K tokens (DeepSeek-Coder)
- **Extended context**: 128K tokens (8x extension)
- **YARN parameters**: s=40, α=1, β=32
- **Training stages**: 32K (1000 steps) → 128K (1000 steps)

## 🔧 Environment Setup

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Tuple, Optional, Dict, List
import math
import warnings
warnings.filterwarnings('ignore')

# Plotting setup
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

print("🧵 YARN Context Extension Learning Environment Ready!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 🌀 RoPE (Rotary Position Embedding) Foundation

### 💡 What is RoPE?

**Rotary Position Embedding** encodes position information by rotating the query and key vectors in complex space.

### 🔑 Mathematical Foundation:

Given position $m$ and head dimension $d$, RoPE applies the rotation:

$$f_q(x_m, m) = (W_q x_m) \otimes e^{im\theta}$$
$$f_k(x_n, n) = (W_k x_n) \otimes e^{in\theta}$$

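where:
- $\theta_j = 10000^{-2j/d}$ for dimension pair $j \in \{0, \dots, d/2 - 1\}$
- $\otimes$ is element-wise complex multiplication
- The relative position is encoded in the dot product: $\langle f_q(x_m, m), f_k(x_n, n) \rangle = \text{Re}\left[\sum_j q_j \, \overline{k_j} \, e^{i(m-n)\theta_j}\right]$, with $q = W_q x_m$ and $k = W_k x_n$ viewed as complex vectors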
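
The relative-position property can be checked numerically with plain complex numbers. The sketch below is illustrative only: the helper names `rope_rotate` and `score` and the toy vectors are assumptions made for this example, not part of any library. It rotates a query and a key at two different absolute offsets that share the same distance $m - n = 3$ and compares the resulting scores.

```python
import cmath

def rope_rotate(pairs, position, thetas):
    """Rotate each complex pair z_j by the angle position * theta_j."""
    return [z * cmath.exp(1j * position * t) for z, t in zip(pairs, thetas)]

def score(q, k):
    """Real dot product of two vectors stored as complex pairs: Re(sum q_j * conj(k_j))."""
    return sum((a * b.conjugate()).real for a, b in zip(q, k))

d = 8  # toy head dimension, i.e. 4 complex pairs
thetas = [10000 ** (-2 * j / d) for j in range(d // 2)]

# Arbitrary toy query/key vectors (hypothetical values)
q = [complex(0.3, -1.2), complex(0.7, 0.1), complex(-0.5, 0.9), complex(1.1, 0.4)]
k = [complex(-0.2, 0.8), complex(0.6, -0.3), complex(0.4, 0.4), complex(-0.9, 0.2)]

# Same relative distance (m - n = 3) at two different absolute offsets
score_a = score(rope_rotate(q, 5, thetas), rope_rotate(k, 2, thetas))
score_b = score(rope_rotate(q, 105, thetas), rope_rotate(k, 102, thetas))

print(f"score at (5, 2):     {score_a:.6f}")
print(f"score at (105, 102): {score_b:.6f}")
```

The two printed scores should agree up to floating-point error, because only the difference $m - n$ survives in the product of the two rotations.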

### 🎯 RoPE Benefits:
1. **Relative positioning**: Attention depends on relative distance
2. **Translation invariance**: Shifting the whole sequence by a constant offset leaves the attention scores unchanged
3. **Efficient implementation**: No additional parameters needed

In [None]:
class RotaryPositionalEmbedding(nn.Module):
    """Standard RoPE implementation
    
    Based on "RoFormer: Enhanced Transformer with Rotary Position Embedding"
    """
    
    def __init__(self, dim: int, max_seq_len: int = 2048, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        
        # Precompute frequencies
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        
        # Cache for efficiency
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
    
    def _update_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
        """Update cosine and sine cache for given sequence length"""
        if seq_len > self._seq_len_cached or self._cos_cached is None:
            self._seq_len_cached = seq_len
            
            # Position indices
            t = torch.arange(seq_len, device=device, dtype=dtype)
            
            # Frequency matrix [seq_len, dim//2]
            freqs = torch.outer(t, self.inv_freq.to(device))
            
            # Combine sin and cos for all dimensions [seq_len, dim]
            emb = torch.cat((freqs, freqs), dim=-1)
            
            self._cos_cached = emb.cos()
            self._sin_cached = emb.sin()
    
    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        """Rotate half the dimensions
        
        [x_0, ..., x_{d/2-1}, x_{d/2}, ..., x_{d-1}] -> [-x_{d/2}, ..., -x_{d-1}, x_0, ..., x_{d/2-1}]
        """
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    
    def apply_rotary_emb(self, q: torch.Tensor, k: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embedding to query and key tensors
        
        Args:
            q: Query tensor [batch_size, num_heads, seq_len, head_dim]
            k: Key tensor [batch_size, num_heads, seq_len, head_dim]
            seq_len: Sequence length
            
        Returns:
            Rotated query and key tensors
        """
        self._update_cos_sin_cache(seq_len, q.device, q.dtype)
        
        # Get cos/sin for current sequence length
        cos = self._cos_cached[:seq_len]
        sin = self._sin_cached[:seq_len]
        
        # Apply rotation
        q_rotated = q * cos + self.rotate_half(q) * sin
        k_rotated = k * cos + self.rotate_half(k) * sin
        
        return q_rotated, k_rotated
    
    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass"""
        seq_len = q.size(-2)
        return self.apply_rotary_emb(q, k, seq_len)

# Demo standard RoPE
print("🌀 Testing Standard RoPE:")
print("=" * 40)

# Parameters
batch_size, num_heads, seq_len, head_dim = 2, 8, 16, 64
rope = RotaryPositionalEmbedding(head_dim, max_seq_len=2048)

# Create query and key tensors
q = torch.randn(batch_size, num_heads, seq_len, head_dim)
k = torch.randn(batch_size, num_heads, seq_len, head_dim)

print(f"Input shapes - Q: {q.shape}, K: {k.shape}")

# Apply RoPE
q_rot, k_rot = rope(q, k)
print(f"Output shapes - Q: {q_rot.shape}, K: {k_rot.shape}")

# Verify relative position property:
# attention scores should depend only on relative distance, not absolute position
attn_original = torch.matmul(q_rot, k_rot.transpose(-2, -1))

# Re-embed tokens 1..15 at positions 0..14 (every token shifted by one position)
q_shifted = q[:, :, 1:]
k_shifted = k[:, :, 1:]
q_shifted_rot, k_shifted_rot = rope.apply_rotary_emb(q_shifted, k_shifted, seq_len-1)
attn_shifted = torch.matmul(q_shifted_rot, k_shifted_rot.transpose(-2, -1))

# Compare the scores of the same token pairs before and after the shift
attn_overlap = attn_original[:, :, 1:, 1:]
max_abs_diff = (attn_overlap - attn_shifted).abs().max().item()

print(f"✅ RoPE applied successfully")
print(f"📊 Translation invariance check:")
print(f"   Mean score (positions 1..{seq_len-1}): {attn_overlap.mean().item():.4f}")
print(f"   Mean score (positions 0..{seq_len-2}): {attn_shifted.mean().item():.4f}")
print(f"   Max absolute difference: {max_abs_diff:.6f}")

## 🧵 YARN Theory & Implementation

### 💡 What is YARN?

**YARN (Yet Another RoPE extensioN)** is a technique for extending the context length of pre-trained models that use RoPE position embeddings.

### 🔑 Core Concepts:

1. **Frequency Scaling**: Scale down frequencies to accommodate longer sequences
2. **Attention Scaling**: Scale the attention logits to keep the attention distribution stable at longer lengths
3. **Interpolation**: Interpolate between scaled and original frequencies
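
For the attention-scaling step, the YaRN paper recommends multiplying both queries and keys by $\sqrt{1/t} = 0.1 \ln(s) + 1$. The sketch below evaluates that factor for the scale quoted above. The function name `yarn_attention_factor` is an assumption made for this example, and the quoted passage does not say whether DeepSeek-Coder-V2 applies exactly this factor.

```python
import math

def yarn_attention_factor(scale: float) -> float:
    """Multiplier for q and k from the YaRN paper: sqrt(1/t) = 0.1 * ln(s) + 1."""
    if scale <= 1.0:
        # No context extension, so no temperature adjustment
        return 1.0
    return 0.1 * math.log(scale) + 1.0

factor = yarn_attention_factor(40.0)

# Scaling q and k by `factor` multiplies every attention logit by factor ** 2
print(f"q/k multiplier for s=40: {factor:.4f}")
print(f"logit multiplier:        {factor ** 2:.4f}")
```

Because the factor is applied to q and k, it can be folded into the cached cos/sin tables, so it adds no cost at inference time.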

### 📊 YARN Formula:

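For frequency dimension $i$, scale factor $s$, and original context length $L$, let $r_i = L\,\theta_i / (2\pi)$ be the number of full rotations that dimension $i$ completes over the original context:

$$\theta_i' = \begin{cases}
\theta_i / s & \text{if } r_i < \alpha \\
\theta_i & \text{if } r_i > \beta \\
(1 - \gamma_i)\,\theta_i / s + \gamma_i\,\theta_i, \quad \gamma_i = \dfrac{r_i - \alpha}{\beta - \alpha} & \text{otherwise}
\end{cases}$$

where:
- $s$: Scale factor (40 in DeepSeek-V2)
- $\alpha$: Rotation count below which a dimension counts as low-frequency and is fully interpolated (1)
- $\beta$: Rotation count above which a dimension counts as high-frequency and is left unchanged (32)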
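
In the YaRN paper, $\alpha$ and $\beta$ are thresholds on the number of rotations $r = L\,\theta / (2\pi)$ a dimension completes over the original context length $L$. The following is a minimal, dependency-free sketch of that piecewise rule. The function name `yarn_inv_freq` and the choice of `dim=64` are assumptions for illustration, and the 16K original length follows the figure used in this notebook.

```python
import math

def yarn_inv_freq(dim, base=10000.0, scale=40.0, original_len=16384,
                  alpha=1.0, beta=32.0):
    """Return (original, YaRN-adjusted) RoPE frequencies for one attention head."""
    original = [base ** (-2 * j / dim) for j in range(dim // 2)]
    adjusted = []
    for theta in original:
        # Full rotations this dimension completes over the original context
        rotations = original_len * theta / (2 * math.pi)
        # Ramp: 0 -> fully interpolated (theta / s), 1 -> left unchanged
        gamma = min(1.0, max(0.0, (rotations - alpha) / (beta - alpha)))
        adjusted.append((1 - gamma) * theta / scale + gamma * theta)
    return original, adjusted

original, adjusted = yarn_inv_freq(dim=64)

# Effective slow-down per dimension: 1 means untouched, 40 means fully interpolated
for j in (0, 16, 31):
    print(f"dim pair {j:2d}: theta / theta' = {original[j] / adjusted[j]:.3f}")
```

With these settings the fastest-rotating pair (index 0) stays untouched, the slowest pair (index 31) is divided by the full scale of 40, and pairs in between receive a blend of the two.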

In [None]:
class YARNRotaryEmbedding(nn.Module):
    """YARN: Yet Another RoPE extensioN
    
    Implementation based on "YaRN: Efficient Context Window Extension of Large Language Models"
    with DeepSeek-V2 hyperparameters
    """
    
    def __init__(
        self,
        dim: int,
        max_seq_len: int = 2048,
        base: float = 10000.0,
        scale: float = 1.0,  # Context extension scale factor
        original_max_seq_len: int = 2048,
        # YARN hyperparameters (DeepSeek-V2 values)
        alpha: float = 1.0,
        beta: float = 32.0,
        extrapolation_factor: float = 1.0,
        finetuned: bool = False
    ):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        self.scale = scale
        self.original_max_seq_len = original_max_seq_len
        self.alpha = alpha
        self.beta = beta
        self.extrapolation_factor = extrapolation_factor
        self.finetuned = finetuned
        
        # Compute YARN frequencies and register them as a buffer so they follow .to(device)
        self.register_buffer('inv_freq', self._compute_yarn_frequencies())
        
        # Cache
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
    
    def _compute_yarn_frequencies(self) -> torch.Tensor:
        """Compute YARN-modified frequencies
        
        alpha and beta are thresholds on the number of full rotations a dimension
        completes over the original context length, as in the YaRN paper.
        """
        # Original RoPE frequencies
        freq_base = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        
        if self.scale <= 1.0:
            # No scaling needed
            return freq_base
        
        # Rotations per dimension over the original context: r = L * theta / (2 * pi)
        rotations = self.original_max_seq_len * freq_base / (2 * math.pi)
        
        # Ramp: 0 where r <= alpha (low frequency, full scaling),
        #       1 where r >= beta (high frequency, no scaling), linear in between
        ramp = torch.clamp((rotations - self.alpha) / (self.beta - self.alpha), 0.0, 1.0)
        
        # Blend the scaled and original frequencies
        return (freq_base / self.scale) * (1 - ramp) + freq_base * ramp
    
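    def _get_attention_scale(self, seq_len: int) -> float:
        """Get a heuristic attention scaling factor for sequences beyond the original context
        
        Note: this is a simplified stand-in for illustration. The YaRN paper instead
        multiplies q and k by 0.1 * ln(s) + 1, which depends on the scale factor s
        rather than on the current sequence length.
        """
        if seq_len <= self.original_max_seq_len:
            return 1.0
        
        # Scale attention based on sequence length extension
        extension_ratio = seq_len / self.original_max_seq_len
        
        if self.finetuned:
            # Heuristic for finetuned models: inverse square-root scaling
            return 1.0 / math.sqrt(extension_ratio)
        else:
            # Heuristic for base models: inverse logarithmic scaling
            return 1.0 / math.log(extension_ratio + 1)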
    
    def _update_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
        """Update cosine/sine cache with YARN modifications"""
        if seq_len > self._seq_len_cached or self._cos_cached is None:
            self._seq_len_cached = seq_len
            
            # Position indices
            t = torch.arange(seq_len, device=device, dtype=dtype)
            
            # Apply extrapolation factor for very long sequences
            if seq_len > self.original_max_seq_len:
                # Use extrapolation factor to reduce high-frequency noise
                t = t * self.extrapolation_factor
            
            # Frequency matrix with YARN frequencies
            freqs = torch.outer(t, self.inv_freq.to(device))
            
            # Combine for all dimensions
            emb = torch.cat((freqs, freqs), dim=-1)
            
            self._cos_cached = emb.cos()
            self._sin_cached = emb.sin()
    
    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        """Rotate half dimensions"""
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    
    def apply_rotary_emb(
        self, 
        q: torch.Tensor, 
        k: torch.Tensor, 
        seq_len: int,
        apply_attention_scaling: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor, float]:
        """Apply YARN rotary embedding
        
        Returns:
            q_rotated, k_rotated, attention_scale
        """
        self._update_cos_sin_cache(seq_len, q.device, q.dtype)
        
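        # Get cos/sin for the last q.size(-2) positions of the seq_len-long context,
        # so tensors covering only the tail of a long sequence still broadcast correctly
        num_positions = q.size(-2)
        cos = self._cos_cached[seq_len - num_positions:seq_len]
        sin = self._sin_cached[seq_len - num_positions:seq_len]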
        
        # Apply rotation
        q_rotated = q * cos + self.rotate_half(q) * sin
        k_rotated = k * cos + self.rotate_half(k) * sin
        
        # Get attention scaling
        attention_scale = 1.0
        if apply_attention_scaling:
            attention_scale = self._get_attention_scale(seq_len)
        
        return q_rotated, k_rotated, attention_scale
    
    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, float]:
        """Forward pass"""
        seq_len = q.size(-2)
        return self.apply_rotary_emb(q, k, seq_len)

# Demo YARN with DeepSeek-V2 parameters
print("🧵 Testing YARN with DeepSeek-V2 Parameters:")
print("=" * 50)

# DeepSeek-V2 YARN configuration
deepseek_yarn_config = {
    'scale': 40.0,  # s=40 in paper
    'alpha': 1.0,   # α=1 in paper  
    'beta': 32.0,   # β=32 in paper
    'original_max_seq_len': 16384,  # 16K original
    'max_seq_len': 131072,  # 128K extended
}

# Create YARN embedding
yarn = YARNRotaryEmbedding(
    dim=64,
    **deepseek_yarn_config
)

print(f"📊 YARN Configuration:")
for key, value in deepseek_yarn_config.items():
    print(f"   {key}: {value}")

# Test at different sequence lengths
test_lengths = [1024, 16384, 32768, 65536, 131072]  # 1K to 128K

print(f"\n🧪 Testing YARN at Different Sequence Lengths:")
print(f"{'Length':<8} {'Scale':<8} {'Status':<15}")
print("-" * 35)

for seq_len in test_lengths:
    # Test with smaller tensors to avoid memory issues
    q_test = torch.randn(1, 1, min(seq_len, 1024), 64)  # Limit to 1K for demo
    k_test = torch.randn(1, 1, min(seq_len, 1024), 64)
    
    try:
        q_rot, k_rot, attn_scale = yarn.apply_rotary_emb(q_test, k_test, seq_len)
        status = "✅ Success"
    except Exception as e:
        attn_scale = float('nan')
        status = f"❌ Error: {str(e)[:10]}..."
    
    print(f"{seq_len:<8} {attn_scale:<8.4f} {status:<15}")

# Visualize frequency modifications
def visualize_yarn_frequencies():
    """Visualize how YARN modifies RoPE frequencies"""
    
    dim = 128
    
    # Original RoPE frequencies
    original_rope = RotaryPositionalEmbedding(dim)
    original_freqs = original_rope.inv_freq.numpy()
    
    # YARN frequencies with different scales
    scales = [1.0, 4.0, 16.0, 40.0]  # Include DeepSeek-V2's scale=40
    
    plt.figure(figsize=(15, 10))
    
    # Plot frequency modifications
    for i, scale in enumerate(scales):
        yarn_test = YARNRotaryEmbedding(
            dim=dim, 
            scale=scale, 
            alpha=1.0, 
            beta=32.0
        )
        yarn_freqs = yarn_test.inv_freq.numpy()
        
        plt.subplot(2, 2, i+1)
        
        freq_indices = np.arange(len(original_freqs))
        
        plt.plot(freq_indices, original_freqs, 'b-', label='Original RoPE', linewidth=2)
        plt.plot(freq_indices, yarn_freqs, 'r-', label=f'YARN (scale={scale})', linewidth=2)
        
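        # Mark where the rotation count over the original context crosses β and α
        rotations = yarn_test.original_max_seq_len * original_freqs / (2 * np.pi)
        beta_idx = int(np.sum(rotations > 32.0))   # first dimension with at most β rotations
        alpha_idx = int(np.sum(rotations >= 1.0))  # first dimension with fewer than α rotations
        
        plt.axvline(beta_idx, color='orange', linestyle='--', alpha=0.7, label='β=32 rotations')
        plt.axvline(alpha_idx, color='green', linestyle='--', alpha=0.7, label='α=1 rotation')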
        
        plt.yscale('log')
        plt.xlabel('Frequency Index')
        plt.ylabel('Inverse Frequency')
        plt.title(f'YARN Frequency Modification (Scale={scale})')
        plt.legend()
        plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

visualize_yarn_frequencies()

print("\n🔍 Key YARN Insights:")
print("• Low frequencies (fewer than α=1 rotations over the original context): Full scaling applied")
print("• High frequencies (more than β=32 rotations): No scaling (preserve fine details)")
print("• Mid frequencies: Interpolated scaling")
print("• DeepSeek-V2 uses scale=40, well above the 8x ratio between 16K and 128K")

## 🔍 "Needle in a Haystack" Evaluation

### 🎯 What is "Needle in a Haystack"?

This test measures a model's ability to find specific information in a long context:
1. **Insert "needle"**: Add a specific fact at a random position
2. **Create "haystack"**: Surround it with irrelevant text
3. **Query**: Ask the model to retrieve the needle
4. **Evaluate**: Measure accuracy across different positions and context lengths
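
The four steps above can be sketched in a few lines. This is an illustrative outline only: `build_case` and `is_retrieved` are hypothetical helpers, the substring-match scoring rule is an assumption, and the model response is a hard-coded string rather than real model output.

```python
def build_case(haystack: str, needle: str, position: float) -> str:
    """Insert the needle at a fractional position (0.0 = start, 1.0 = end)."""
    cut = int(len(haystack) * position)
    return haystack[:cut] + "\n" + needle + "\n" + haystack[cut:]

def is_retrieved(response: str, expected: str) -> bool:
    """Hypothetical scoring rule: case-insensitive substring match."""
    return expected.lower() in response.lower()

# Steps 1-2: a filler haystack with the needle placed in the middle
haystack = "x = 1\n" * 200
needle = "The secret code is: ALPHA-7429"
context = build_case(haystack, needle, position=0.5)

# Step 3: the query that would be sent to a model together with the context
query = "What is the secret code mentioned in the text?"

# Step 4: score a response (a stand-in string here, not real model output)
mock_response = "The code is alpha-7429."
print(is_retrieved(mock_response, "ALPHA-7429"))
```

Substring matching is a deliberately lenient choice. A stricter evaluation might require an exact match or use a judge model.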

### 📊 DeepSeek-V2 Results (Figure 2):
The model reaches >95% accuracy across all context lengths up to 128K tokens

In [None]:
class NeedleInHaystackEvaluator:
    """Needle in a Haystack evaluation for long context models"""
    
    def __init__(self, model_name: str = "DeepSeek-Coder-V2"):
        self.model_name = model_name
        
        # Needle templates
        self.needles = [
            "The secret code is: ALPHA-7429",
            "Remember this number: 185304",
            "The password is: quantum_bridge_2024",
            "Important: The key is stored in vault-789",
            "Critical information: Project Nebula status is ACTIVE"
        ]
        
        # Haystack text (code-like content for DeepSeek-Coder)
        self.haystack_templates = [
            "def process_data(input_file):\n    with open(input_file, 'r') as f:\n        data = f.read()\n    return data.strip()",
            "class DataProcessor:\n    def __init__(self, config):\n        self.config = config\n    def run(self):\n        pass",
            "import numpy as np\nimport pandas as pd\nfrom sklearn.model_selection import train_test_split",
            "# Configuration settings\nDATABASE_URL = 'postgresql://localhost/mydb'\nAPI_KEY = 'test_key_123'",
            "async function fetchData(url) {\n    const response = await fetch(url);\n    return response.json();\n}"
        ]
    
    def generate_haystack(self, target_length: int) -> str:
        """Generate haystack text of target length"""
        haystack = ""
        
        while len(haystack) < target_length:
            # Add random code snippet
            template = np.random.choice(self.haystack_templates)
            haystack += template + "\n\n"
            
            # Add some random variables/comments
            if np.random.random() > 0.7:
                haystack += f"# Random comment {np.random.randint(1000, 9999)}\n"
            
            if np.random.random() > 0.8:
                haystack += f"variable_{np.random.randint(100, 999)} = {np.random.randint(1, 100)}\n"
        
        return haystack[:target_length]
    
    def create_test_case(
        self, 
        context_length: int, 
        needle_position: float
    ) -> Dict[str, str]:
        """Create a single needle-in-haystack test case
        
        Args:
            context_length: Total context length in characters
            needle_position: Position of needle (0.0 = start, 1.0 = end)
        """
        # Select random needle
        needle = np.random.choice(self.needles)
        
        # Generate haystack
        haystack = self.generate_haystack(context_length - len(needle) - 100)  # Leave room
        
        # Calculate insertion position
        insert_pos = int(len(haystack) * needle_position)
        
        # Insert needle
        context = (
            haystack[:insert_pos] + 
            "\n" + needle + "\n" + 
            haystack[insert_pos:]
        )
        
        # Create query based on needle type
        if "code" in needle.lower():
            query = "What is the secret code mentioned in the text?"
        elif "number" in needle.lower():
            query = "What number should I remember?"
        elif "password" in needle.lower():
            query = "What is the password?"
        elif "vault" in needle.lower():
            query = "Where is the key stored?"
        else:
            query = "What is the status of Project Nebula?"
        
        return {
            'context': context,
            'query': query,
            'needle': needle,
            'expected_answer': needle.split(': ')[1] if ': ' in needle else needle,
            'position': needle_position,
            'length': len(context)
        }
    
    def simulate_model_performance(
        self, 
        context_length: int, 
        needle_position: float,
        use_yarn: bool = True
    ) -> float:
        """Simulate model performance on needle-in-haystack task
        
        Returns a synthetic accuracy score (0.0 to 1.0); no model is actually run
        """
        # Simulate DeepSeek-V2 performance based on paper results
        base_accuracy = 0.95  # Paper reports >95% accuracy
        
        if not use_yarn:
            # Without YARN, performance degrades significantly for long contexts
            if context_length > 16384:  # Beyond original context length
                degradation = min(0.8, (context_length - 16384) / 50000)
                base_accuracy *= (1 - degradation)
        
        # Position effects (harder to find needle at the middle)
        position_penalty = 0.05 * (1 - abs(needle_position - 0.5) * 2)  # Max penalty at middle
        
        # Length effects (slight degradation for very long contexts)
        length_penalty = min(0.1, context_length / 1000000)  # Very gentle degradation
        
        # Add some randomness
        noise = np.random.normal(0, 0.02)  # Small noise
        
        accuracy = base_accuracy - position_penalty - length_penalty + noise
        return np.clip(accuracy, 0.0, 1.0)
    
    def run_evaluation(
        self,
        context_lengths: List[int],
        position_percentiles: List[float],
        num_trials: int = 5
    ) -> Dict[str, np.ndarray]:
        """Run comprehensive needle-in-haystack evaluation"""
        
        results_yarn = np.zeros((len(context_lengths), len(position_percentiles)))
        results_no_yarn = np.zeros((len(context_lengths), len(position_percentiles)))
        
        print(f"🔍 Running Needle-in-Haystack Evaluation...")
        print(f"   Context lengths: {context_lengths}")
        print(f"   Position percentiles: {position_percentiles}")
        print(f"   Trials per condition: {num_trials}")
        
        for i, length in enumerate(context_lengths):
            for j, position in enumerate(position_percentiles):
                # Multiple trials for averaging
                yarn_scores = []
                no_yarn_scores = []
                
                for trial in range(num_trials):
                    # With YARN
                    yarn_score = self.simulate_model_performance(
                        length, position, use_yarn=True
                    )
                    yarn_scores.append(yarn_score)
                    
                    # Without YARN
                    no_yarn_score = self.simulate_model_performance(
                        length, position, use_yarn=False
                    )
                    no_yarn_scores.append(no_yarn_score)
                
                results_yarn[i, j] = np.mean(yarn_scores)
                results_no_yarn[i, j] = np.mean(no_yarn_scores)
        
        return {
            'yarn': results_yarn,
            'no_yarn': results_no_yarn,
            'context_lengths': context_lengths,
            'positions': position_percentiles
        }

# Run needle-in-haystack evaluation
print("🎯 Needle-in-Haystack Evaluation:")
print("=" * 40)

evaluator = NeedleInHaystackEvaluator()

# Full test grid (the demo run below uses a reduced grid)
context_lengths = [1000, 4000, 8000, 16000, 32000, 64000, 128000]  # Up to 128K
position_percentiles = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

# Generate a sample test case
sample_test = evaluator.create_test_case(context_length=2000, needle_position=0.5)
print(f"📋 Sample Test Case:")
print(f"   Query: {sample_test['query']}")
print(f"   Needle: {sample_test['needle']}")
print(f"   Expected: {sample_test['expected_answer']}")
print(f"   Context length: {sample_test['length']} chars")
print(f"   Position: {sample_test['position']:.1%}")

# Run evaluation (reduced for demo)
eval_results = evaluator.run_evaluation(
    context_lengths=[1000, 8000, 16000, 64000, 128000],
    position_percentiles=[0.0, 0.25, 0.5, 0.75, 1.0],
    num_trials=3
)

print(f"\n✅ Evaluation completed!")

## 📊 Visualizing YARN Performance

### 🎨 Performance Heatmaps & Analysis

In [None]:
def visualize_needle_in_haystack_results(results: Dict[str, np.ndarray]):
    """Visualize needle-in-haystack evaluation results"""
    
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    
    context_lengths = results['context_lengths']
    positions = results['positions']
    
    # 1. YARN Performance Heatmap
    ax1 = axes[0, 0]
    im1 = ax1.imshow(results['yarn'], cmap='RdYlGn', vmin=0.5, vmax=1.0, aspect='auto')
    ax1.set_title('DeepSeek-V2 with YARN\nNeedle-in-Haystack Performance')
    ax1.set_xlabel('Document Position (%)')
    ax1.set_ylabel('Context Length')
    
    # Set ticks
    ax1.set_xticks(range(len(positions)))
    ax1.set_xticklabels([f'{p:.0%}' for p in positions])
    ax1.set_yticks(range(len(context_lengths)))
    ax1.set_yticklabels([f'{l//1000}K' for l in context_lengths])
    
    # Add colorbar
    plt.colorbar(im1, ax=ax1, fraction=0.046, pad=0.04, label='Accuracy')
    
    # 2. No-YARN Performance Heatmap
    ax2 = axes[0, 1]
    im2 = ax2.imshow(results['no_yarn'], cmap='RdYlGn', vmin=0.5, vmax=1.0, aspect='auto')
    ax2.set_title('Without YARN\nNeedle-in-Haystack Performance')
    ax2.set_xlabel('Document Position (%)')
    ax2.set_ylabel('Context Length')
    
    ax2.set_xticks(range(len(positions)))
    ax2.set_xticklabels([f'{p:.0%}' for p in positions])
    ax2.set_yticks(range(len(context_lengths)))
    ax2.set_yticklabels([f'{l//1000}K' for l in context_lengths])
    
    plt.colorbar(im2, ax=ax2, fraction=0.046, pad=0.04, label='Accuracy')
    
    # 3. Performance vs Context Length
    ax3 = axes[1, 0]
    
    # Average across all positions
    yarn_avg = results['yarn'].mean(axis=1)
    no_yarn_avg = results['no_yarn'].mean(axis=1)
    
    ax3.plot([l//1000 for l in context_lengths], yarn_avg, 'g-o', linewidth=2, 
             markersize=8, label='With YARN')
    ax3.plot([l//1000 for l in context_lengths], no_yarn_avg, 'r-s', linewidth=2, 
             markersize=8, label='Without YARN')
    
    ax3.set_xlabel('Context Length (K tokens)')
    ax3.set_ylabel('Average Accuracy')
    ax3.set_title('Performance vs Context Length')
    ax3.legend()
    ax3.grid(True, alpha=0.3)
    ax3.set_ylim(0.5, 1.0)
    
    # Add DeepSeek-V2 reference line
    ax3.axhline(y=0.95, color='blue', linestyle='--', alpha=0.7, 
                label='DeepSeek-V2 Target (95%)')
    ax3.legend()
    
    # 4. Performance vs Position
    ax4 = axes[1, 1]
    
    # Performance at longest context (128K)
    longest_ctx_idx = -1  # Last index (longest context)
    yarn_pos = results['yarn'][longest_ctx_idx, :]
    no_yarn_pos = results['no_yarn'][longest_ctx_idx, :]
    
    position_labels = [f'{p:.0%}' for p in positions]
    
    ax4.plot(position_labels, yarn_pos, 'g-o', linewidth=2, markersize=8, label='With YARN')
    ax4.plot(position_labels, no_yarn_pos, 'r-s', linewidth=2, markersize=8, label='Without YARN')
    
    ax4.set_xlabel('Needle Position in Document')
    ax4.set_ylabel('Accuracy')
    ax4.set_title(f'Performance vs Position\n(Context Length: {context_lengths[-1]//1000}K)')
    ax4.legend()
    ax4.grid(True, alpha=0.3)
    ax4.set_ylim(0.5, 1.0)
    ax4.tick_params(axis='x', rotation=45)
    
    plt.tight_layout()
    plt.show()
    
    # Performance summary
    print("\n📊 Performance Summary:")
    print("=" * 40)
    
    print(f"YARN Results:")
    print(f"   Average accuracy: {results['yarn'].mean():.2%}")
    print(f"   Min accuracy: {results['yarn'].min():.2%}")
    print(f"   Max accuracy: {results['yarn'].max():.2%}")
    print(f"   At 128K context: {results['yarn'][-1].mean():.2%}")
    
    print(f"\nWithout YARN:")
    print(f"   Average accuracy: {results['no_yarn'].mean():.2%}")
    print(f"   Min accuracy: {results['no_yarn'].min():.2%}")
    print(f"   Max accuracy: {results['no_yarn'].max():.2%}")
    print(f"   At 128K context: {results['no_yarn'][-1].mean():.2%}")
    
    improvement = (results['yarn'].mean() - results['no_yarn'].mean()) / results['no_yarn'].mean() * 100
    print(f"\n🚀 YARN Improvement: {improvement:.1f}%")

# Visualize results
visualize_needle_in_haystack_results(eval_results)

## 🧪 YARN Training Simulation

### 🏋️ Simulating DeepSeek-V2's Two-Stage Training

According to the paper:
1. **Stage 1**: 32K context, 1152 batch size, 1000 steps
2. **Stage 2**: 128K context, 288 batch size, 1000 steps
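
A quick calculation shows how the two stages relate. The sketch below assumes that "K" means 1024 tokens and that every sequence in a batch is filled to the full context length. Neither assumption is stated in the text above.

```python
K = 1024  # assumption: 32K = 32 * 1024 tokens

# (name, context length in tokens, batch size, steps)
stages = [
    ("Stage 1", 32 * K, 1152, 1000),
    ("Stage 2", 128 * K, 288, 1000),
]

# Tokens processed per optimizer step and over the whole stage
tokens_per_step = [ctx * batch for _, ctx, batch, _ in stages]
total_tokens = [per_step * steps
                for per_step, (_, _, _, steps) in zip(tokens_per_step, stages)]

for (name, _, _, _), per_step, total in zip(stages, tokens_per_step, total_tokens):
    print(f"{name}: {per_step:,} tokens/step, {total / 1e9:.2f}B tokens total")
```

Under these assumptions both stages process 37,748,736 tokens per step: the context grows 4x while the batch size shrinks 4x, so the token budget per step stays constant.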

In [None]:
class YARNTrainingSimulator:
    """Simulate YARN training process for context extension"""
    
    def __init__(self, original_context_len: int = 16384):
        self.original_context_len = original_context_len
        self.training_stages = [
            {
                'name': 'Stage 1: Gradual Extension',
                'context_length': 32768,  # 32K
                'batch_size': 1152,
                'steps': 1000,
                'learning_rate': 1e-5
            },
            {
                'name': 'Stage 2: Full Extension', 
                'context_length': 131072,  # 128K
                'batch_size': 288,
                'steps': 1000,
                'learning_rate': 5e-6
            }
        ]
    
    def simulate_training_stage(
        self, 
        stage_config: Dict, 
        initial_performance: float = 0.85
    ) -> List[float]:
        """Simulate training performance over steps"""
        
        steps = stage_config['steps']
        context_len = stage_config['context_length']
        
        # Performance starts lower for longer contexts
        context_penalty = (context_len / self.original_context_len - 1) * 0.1
        start_perf = max(0.7, initial_performance - context_penalty)
        
        # Target performance (based on paper results)
        target_perf = 0.95
        
        # Simulate learning curve
        performance_curve = []
        
        for step in range(steps + 1):
            # Exponential improvement with some noise
            progress = 1 - np.exp(-3 * step / steps)  # Exponential approach
            current_perf = start_perf + (target_perf - start_perf) * progress
            
            # Add training noise
            noise = np.random.normal(0, 0.01)
            current_perf = np.clip(current_perf + noise, 0.6, 1.0)
            
            performance_curve.append(current_perf)
        
        return performance_curve
    
    def run_full_training(self) -> Dict[str, List[float]]:
        """Run complete YARN training simulation"""
        
        print("🏋️ Simulating YARN Training Process:")
        print("=" * 50)
        
        all_results = {}
        initial_perf = 0.90  # Start with good base model performance
        
        for i, stage in enumerate(self.training_stages):
            print(f"\n📈 {stage['name']}:")
            print(f"   Context Length: {stage['context_length']:,} tokens")
            print(f"   Batch Size: {stage['batch_size']}")
            print(f"   Steps: {stage['steps']}")
            print(f"   Learning Rate: {stage['learning_rate']}")
            
            # Simulate training
            performance = self.simulate_training_stage(stage, initial_perf)
            all_results[f'stage_{i+1}'] = performance
            
            # Update initial performance for next stage
            initial_perf = performance[-1]
            
            print(f"   Final Performance: {performance[-1]:.2%}")
        
        return all_results
    
    def visualize_training(self, results: Dict[str, List[float]]):
        """Visualize training progress"""
        
        fig, axes = plt.subplots(2, 2, figsize=(16, 10))
        
        # 1. Training curves
        ax1 = axes[0, 0]
        
        colors = ['blue', 'green']
        for i, (stage_name, performance) in enumerate(results.items()):
            steps = range(len(performance))
            ax1.plot(steps, performance, color=colors[i], linewidth=2, 
                    label=f'Stage {i+1}', alpha=0.8)
        
        ax1.set_xlabel('Training Steps')
        ax1.set_ylabel('Performance (Accuracy)')
        ax1.set_title('YARN Training Progress')
        ax1.grid(True, alpha=0.3)
        ax1.set_ylim(0.7, 1.0)
        
        # Add target line
        ax1.axhline(y=0.95, color='red', linestyle='--', alpha=0.7, 
                   label='Target (95%)')
        ax1.legend()
        
        # 2. Context length progression
        ax2 = axes[0, 1]
        
        context_lengths = [16, 32, 128]  # K tokens
        # The 16K value is an assumed baseline for illustration, not a simulated result
        final_performances = [0.95, results['stage_1'][-1], results['stage_2'][-1]]
        
        bars = ax2.bar(range(len(context_lengths)), final_performances, 
                      color=['gray', 'blue', 'green'], alpha=0.7)
        
        ax2.set_xlabel('Context Length (K tokens)')
        ax2.set_ylabel('Final Performance')
        ax2.set_title('Performance vs Context Length')
        ax2.set_xticks(range(len(context_lengths)))
        ax2.set_xticklabels([f'{ctx}K' for ctx in context_lengths])
        ax2.set_ylim(0.8, 1.0)
        
        # Add value labels
        for bar, value in zip(bars, final_performances):
            height = bar.get_height()
            ax2.text(bar.get_x() + bar.get_width()/2., height + 0.005,
                    f'{value:.2%}', ha='center', va='bottom', fontweight='bold')
        
        # 3. Memory usage simulation
        ax3 = axes[1, 0]
        
        context_lens = [16, 32, 64, 128]  # K tokens
        # Naive attention materializes an n x n score matrix, so that term grows quadratically
        memory_usage = [ctx**2 / 16**2 for ctx in context_lens]  # Relative to 16K
        
        ax3.plot(context_lens, memory_usage, 'ro-', linewidth=2, markersize=8)
        ax3.set_xlabel('Context Length (K tokens)')
        ax3.set_ylabel('Relative Memory Usage')
        ax3.set_title('Memory Scaling with Context Length')
        ax3.grid(True, alpha=0.3)
        ax3.set_yscale('log')
        
        # Add annotations
        for ctx, mem in zip(context_lens, memory_usage):
            ax3.annotate(f'{mem:.1f}x', (ctx, mem), 
                        textcoords="offset points", xytext=(0,10), ha='center')
        
        # 4. Batch size adjustments
        ax4 = axes[1, 1]
        
        stages = ['Original\n(16K)', 'Stage 1\n(32K)', 'Stage 2\n(128K)']
        batch_sizes = [2048, 1152, 288]  # Estimated original, then from paper
        context_sizes = [16, 32, 128]
        
        # Effective tokens per batch
        effective_tokens = [b * c for b, c in zip(batch_sizes, context_sizes)]
        
        ax4_twin = ax4.twinx()
        
        bars1 = ax4.bar(range(len(stages)), batch_sizes, alpha=0.7, 
                       color='skyblue', label='Batch Size')
        line1 = ax4_twin.plot(range(len(stages)), effective_tokens, 'ro-', 
                             linewidth=2, markersize=8, label='Effective Tokens')
        
        ax4.set_xlabel('Training Stage')
        ax4.set_ylabel('Batch Size', color='blue')
        ax4_twin.set_ylabel('Effective Tokens (K)', color='red')
        ax4.set_title('Batch Size Adjustment Strategy')
        ax4.set_xticks(range(len(stages)))
        ax4.set_xticklabels(stages)
        
        # Add value labels (effective_tokens is already in K tokens, so /1000 gives millions)
        for i, (batch, tokens) in enumerate(zip(batch_sizes, effective_tokens)):
            ax4.text(i, batch + 50, str(batch), ha='center', va='bottom')
            ax4_twin.text(i, tokens + 1000, f'{tokens / 1000:.1f}M', ha='center', va='bottom', color='red')
        
        plt.tight_layout()
        plt.show()
        
        return fig

# Run YARN training simulation
trainer = YARNTrainingSimulator()
training_results = trainer.run_full_training()
trainer.visualize_training(training_results)

print("\n🎯 YARN Training Insights:")
print("• Two-stage training allows gradual adaptation")
print("• Batch size shrinks as context grows, keeping tokens per step roughly constant")
print("• The simulated curves hold performance across the 8x context extension")
print("• Naive attention memory scales quadratically, requiring careful optimization")
print("• The paper reports strong Needle-in-a-Haystack results for DeepSeek-Coder-V2 up to 128K tokens")
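
The batch-size insight can be checked with a little arithmetic. The sketch below assumes 1K = 1024 tokens and reuses the two stage settings from the simulation (32K with batch size 1152, 128K with batch size 288). `tokens_per_batch` and `naive_attention_score_bytes` are hypothetical helper names, and the memory figure covers only a naively materialized fp16 score matrix for one head of one sequence, not a real training footprint.

```python
def tokens_per_batch(batch_size: int, context_k: int) -> int:
    """Tokens in one optimizer step, assuming 1K = 1024 tokens."""
    return batch_size * context_k * 1024


def naive_attention_score_bytes(context_k: int, bytes_per_value: int = 2) -> int:
    """Bytes for one full (n x n) attention score matrix.

    Assumes a naive implementation that materializes the whole matrix
    for a single head of a single sequence (fp16 by default).
    """
    n = context_k * 1024
    return n * n * bytes_per_value


# (batch size, context length in K tokens) for the two extension stages
stages = {"stage_1": (1152, 32), "stage_2": (288, 128)}

for name, (batch_size, context_k) in stages.items():
    tokens = tokens_per_batch(batch_size, context_k)
    score_gib = naive_attention_score_bytes(context_k) / 2**30
    print(f"{name}: {tokens / 1e6:.1f}M tokens/step, "
          f"{score_gib:.0f} GiB per naive score matrix")
```

Under these assumptions both stages process the same number of tokens per step, so the 4x longer sequences are offset by the 4x smaller batch, while the naive score matrix grows 16x.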

## 💻 Code Understanding with Long Context

### 🔍 Real-world Applications of 128K Context

This section demonstrates practical applications of YARN context extension in code understanding tasks.

In [None]:
from typing import Any, Dict  # Any is not among the setup cell's imports

class LongContextCodeAnalyzer:
    """Analyze code understanding tasks that benefit from long context"""
    
    def __init__(self):
        self.context_limits = {
            'original': 16384,  # 16K tokens
            'yarn_extended': 131072  # 128K tokens
        }
    
    def estimate_tokens(self, text: str) -> int:
        """Rough token estimation (1 token ≈ 4 characters for code)"""
        return len(text) // 4
    
    def generate_large_codebase_example(self, num_files: int = 10) -> Dict[str, str]:
        """Generate example of large codebase"""
        
        codebase = {}
        
        # Main application file
        codebase['main.py'] = '''#!/usr/bin/env python3
"""
Large-scale data processing application
Processes millions of records with ML pipeline
"""

import asyncio
import logging
from typing import List, Dict, Optional
from dataclasses import dataclass
from pathlib import Path

from data_processor import DataProcessor
from ml_pipeline import MLPipeline
from config_manager import ConfigManager
from database import DatabaseManager

@dataclass
class ProcessingConfig:
    batch_size: int = 1000
    max_workers: int = 8
    timeout: int = 300
    retry_count: int = 3

class ApplicationManager:
    def __init__(self, config_path: str):
        self.config = ConfigManager.load(config_path)
        self.db = DatabaseManager(self.config.database_url)
        self.processor = DataProcessor(self.config.processing)
        self.ml_pipeline = MLPipeline(self.config.ml_config)
        self.logger = logging.getLogger(__name__)
    
    async def run_processing_pipeline(self):
        """Main processing pipeline"""
        try:
            # Initialize components
            await self.db.connect()
            await self.processor.initialize()
            
            # Process data in batches
            batch_count = 0
            async for batch in self.db.get_data_batches():
                processed_batch = await self.processor.process_batch(batch)
                ml_results = await self.ml_pipeline.predict(processed_batch)
                await self.db.save_results(ml_results)
                
                batch_count += 1
                self.logger.info(f"Processed batch {batch_count}")
                
            self.logger.info(f"Pipeline completed. Processed {batch_count} batches.")
            
        except Exception as e:
            self.logger.error(f"Pipeline failed: {e}")
            raise
        finally:
            await self.db.disconnect()

if __name__ == "__main__":
    app = ApplicationManager("config.yaml")
    asyncio.run(app.run_processing_pipeline())
'''
        
        # Data processor module
        codebase['data_processor.py'] = '''import pandas as pd
import numpy as np
from typing import List, Dict, Any
import asyncio
from concurrent.futures import ThreadPoolExecutor

class DataProcessor:
    def __init__(self, config):
        self.config = config
        self.executor = ThreadPoolExecutor(max_workers=config.max_workers)
        
    async def initialize(self):
        """Initialize processor"""
        pass
        
    def _clean_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Clean and preprocess data"""
        # Remove duplicates
        df = df.drop_duplicates()
        
        # Handle missing values
        numeric_cols = df.select_dtypes(include=[np.number]).columns
        df[numeric_cols] = df[numeric_cols].fillna(df[numeric_cols].mean())
        
        # Normalize text columns
        text_cols = df.select_dtypes(include=['object']).columns
        for col in text_cols:
            df[col] = df[col].str.strip().str.lower()
            
        return df
    
    def _extract_features(self, df: pd.DataFrame) -> pd.DataFrame:
        """Extract features for ML"""
        # Feature engineering logic here
        features = df.copy()
        
        # Add derived features
        if 'timestamp' in features.columns:
            features['hour'] = pd.to_datetime(features['timestamp']).dt.hour
            features['day_of_week'] = pd.to_datetime(features['timestamp']).dt.dayofweek
            
        return features
    
    async def process_batch(self, batch_data: List[Dict]) -> pd.DataFrame:
        """Process a batch of data"""
        df = pd.DataFrame(batch_data)
        
        # Run CPU-intensive operations in thread pool
        loop = asyncio.get_running_loop()
        
        # Clean data
        df = await loop.run_in_executor(self.executor, self._clean_data, df)
        
        # Extract features
        df = await loop.run_in_executor(self.executor, self._extract_features, df)
        
        return df
'''

        # Add more files to reach target size
        for i in range(num_files - 2):
            codebase[f'module_{i}.py'] = f'''# Module {i}: Utility functions and classes

import logging
from typing import Any, Dict, List, Optional

class Utility{i}:
    """Utility class for module {i}"""
    
    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.logger = logging.getLogger(f"module_{i}")
    
    def process_data_{i}(self, data: List[Dict]) -> List[Dict]:
        """Process data specific to module {i}"""
        processed = []
        for item in data:
            # Complex processing logic here
            result = {{
                'id': item.get('id'),
                'processed_value': item.get('value', 0) * {i + 1},
                'metadata': {{
                    'processor': f'module_{i}',
                    'version': '1.0.{i}'
                }}
            }}
            processed.append(result)
        return processed
    
    def validate_data_{i}(self, data: Dict[str, Any]) -> bool:
        """Validate data for module {i}"""
        required_fields = ['id', 'value', 'timestamp']
        return all(field in data for field in required_fields)

def helper_function_{i}(x: float, y: float) -> float:
    """Helper function for module {i}"""
    return (x + y) * {i + 1} / (x - y + 0.001)

# Constants for module {i}
MODULE_{i}_VERSION = "1.0.{i}"
MODULE_{i}_CONFIG = {{
    'max_batch_size': {100 * (i + 1)},
    'timeout': {30 + i * 5},
    'retry_count': {3 + i}
}}
''' * 10  # Make each module longer
        
        return codebase
    
    def analyze_codebase_scenarios(self, codebase: Dict[str, str]) -> Dict[str, Any]:
        """Analyze different scenarios for code understanding"""
        
        # Calculate total size
        total_text = "\n\n".join(f"# File: {filename}\n{content}" 
                                for filename, content in codebase.items())
        total_tokens = self.estimate_tokens(total_text)
        
        scenarios = {
            'cross_file_analysis': {
                'description': 'Analyze dependencies and interactions across multiple files',
                'requires_full_context': True,
                'example_query': 'How does main.py interact with data_processor.py and what are the data flow patterns?',
                'context_needed': total_tokens
            },
            'refactoring_analysis': {
                'description': 'Identify refactoring opportunities across the entire codebase',
                'requires_full_context': True,
                'example_query': 'What duplicate code patterns exist across modules and how can they be consolidated?',
                'context_needed': total_tokens
            },
            'bug_analysis': {
                'description': 'Trace bugs that span multiple files and modules',
                'requires_full_context': True,
                'example_query': 'If there\'s a data corruption issue, trace the data flow from input to output',
                'context_needed': total_tokens
            },
            'architecture_review': {
                'description': 'Understand overall system architecture and design patterns',
                'requires_full_context': True,
                'example_query': 'Describe the overall architecture and suggest improvements',
                'context_needed': total_tokens
            },
            'single_file_analysis': {
                'description': 'Analyze individual files in isolation',
                'requires_full_context': False,
                'example_query': 'Explain the DataProcessor class implementation',
                'context_needed': self.estimate_tokens(codebase['data_processor.py'])
            }
        }
        
        return {
            'total_tokens': total_tokens,
            'total_files': len(codebase),
            'scenarios': scenarios
        }
    
    def evaluate_context_requirements(self, analysis: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
        """Evaluate which scenarios can be handled with different context limits"""
        
        results = {}
        
        for scenario_name, scenario in analysis['scenarios'].items():
            context_needed = scenario['context_needed']
            
            results[scenario_name] = {
                'can_handle_original': context_needed <= self.context_limits['original'],
                'can_handle_yarn': context_needed <= self.context_limits['yarn_extended'],
                'context_needed': context_needed,
                'requires_full_context': scenario['requires_full_context']
            }
        
        return results

# Demonstrate long context code analysis
print("💻 Long Context Code Understanding Analysis:")
print("=" * 50)

analyzer = LongContextCodeAnalyzer()

# Generate large codebase
print("🏗️ Generating large codebase example...")
large_codebase = analyzer.generate_large_codebase_example(num_files=15)

# Analyze scenarios
analysis = analyzer.analyze_codebase_scenarios(large_codebase)
context_eval = analyzer.evaluate_context_requirements(analysis)

print(f"\n📊 Codebase Statistics:")
print(f"   Total files: {analysis['total_files']}")
print(f"   Total tokens: {analysis['total_tokens']:,}")
print(f"   Size vs 16K limit: {analysis['total_tokens'] / analyzer.context_limits['original']:.1f}x")
print(f"   Size vs 128K limit: {analysis['total_tokens'] / analyzer.context_limits['yarn_extended']:.1f}x")

print(f"\n🔍 Scenario Analysis:")
print(f"{'Scenario':<25} {'Tokens':<8} {'16K OK':<8} {'128K OK':<8} {'Full Context':<12}")
print("-" * 70)

for scenario_name, eval_result in context_eval.items():
    tokens = eval_result['context_needed']
    can_16k = "✅" if eval_result['can_handle_original'] else "❌"
    can_128k = "✅" if eval_result['can_handle_yarn'] else "❌"
    full_context = "Yes" if eval_result['requires_full_context'] else "No"
    
    print(f"{scenario_name:<25} {tokens:<8,} {can_16k:<8} {can_128k:<8} {full_context:<12}")

# Visualize context requirements
def visualize_context_benefits():
    """Visualize benefits of YARN context extension"""
    
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    
    # 1. Context limits comparison
    ax1 = axes[0]
    
    scenarios = list(context_eval.keys())
    tokens_needed = [context_eval[s]['context_needed'] for s in scenarios]
    
    bars = ax1.barh(scenarios, tokens_needed, alpha=0.7)
    
    # Add context limit lines
    ax1.axvline(x=analyzer.context_limits['original'], color='red', 
               linestyle='--', linewidth=2, label='16K Limit (Original)')
    ax1.axvline(x=analyzer.context_limits['yarn_extended'], color='green', 
               linestyle='--', linewidth=2, label='128K Limit (YARN)')
    
    # Color bars based on feasibility
    for i, (bar, scenario) in enumerate(zip(bars, scenarios)):
        if context_eval[scenario]['can_handle_original']:
            bar.set_color('lightgreen')
        elif context_eval[scenario]['can_handle_yarn']:
            bar.set_color('orange')
        else:
            bar.set_color('lightcoral')
    
    ax1.set_xlabel('Tokens Required')
    ax1.set_title('Context Requirements by Scenario')
    ax1.legend()
    ax1.set_xscale('log')
    ax1.grid(True, alpha=0.3)
    
    # 2. Capability comparison
    ax2 = axes[1]
    
    capabilities = ['16K Context', '128K Context']
    scenario_counts = {
        '16K Context': sum(1 for s in context_eval.values() if s['can_handle_original']),
        '128K Context': sum(1 for s in context_eval.values() if s['can_handle_yarn'])
    }
    
    bars = ax2.bar(capabilities, [scenario_counts[cap] for cap in capabilities], 
                   color=['red', 'green'], alpha=0.7)
    
    ax2.set_ylabel('Scenarios Supported')
    ax2.set_title('Scenarios Supported by Context Length')
    ax2.set_ylim(0, len(scenarios))
    
    # Add value labels
    for bar, cap in zip(bars, capabilities):
        height = bar.get_height()
        total = len(scenarios)
        percentage = height / total * 100
        ax2.text(bar.get_x() + bar.get_width()/2., height + 0.1,
                f'{int(height)}/{total}\n({percentage:.0f}%)', 
                ha='center', va='bottom', fontweight='bold')
    
    plt.tight_layout()
    plt.show()

visualize_context_benefits()

print(f"\n🚀 YARN Context Extension Benefits:")
original_capable = sum(1 for s in context_eval.values() if s['can_handle_original'])
yarn_capable = sum(1 for s in context_eval.values() if s['can_handle_yarn'])
total_scenarios = len(context_eval)

print(f"• 16K context handles: {original_capable}/{total_scenarios} scenarios ({original_capable/total_scenarios:.0%})")
print(f"• 128K context handles: {yarn_capable}/{total_scenarios} scenarios ({yarn_capable/total_scenarios:.0%})")
print(f"• YARN enables {yarn_capable - original_capable} additional complex scenarios")
print(f"• Critical for: cross-file analysis, refactoring, architecture review")
print(f"• 8x context extension unlocks repository-level understanding")
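
To make the 16K versus 128K comparison concrete, here is a minimal sketch of packing files into a prompt under a token budget. It reuses the same rough 4-characters-per-token heuristic as the analyzer above; `pack_files` and the toy codebase are hypothetical names introduced only for illustration, and a real tokenizer would give different counts.

```python
from typing import Dict, List, Tuple


def estimate_tokens(text: str) -> int:
    """Same rough heuristic as the analyzer above: about 4 characters per token."""
    return len(text) // 4


def pack_files(codebase: Dict[str, str], budget: int) -> Tuple[List[str], int]:
    """Greedily add files (in dict order) while they fit in the token budget.

    Returns the names of the files that fit and the tokens they use.
    """
    selected: List[str] = []
    used = 0
    for name, source in codebase.items():
        cost = estimate_tokens(f"# File: {name}\n{source}")
        if used + cost > budget:
            continue  # skip files that do not fit; smaller ones may still fit
        selected.append(name)
        used += cost
    return selected, used


# Toy codebase: three files of 40K, 40K and 400K characters
toy = {"a.py": "x" * 40_000, "b.py": "y" * 40_000, "c.py": "z" * 400_000}

for label, budget in [("16K", 16_384), ("128K", 131_072)]:
    files, used = pack_files(toy, budget)
    print(f"{label}: {len(files)} files, {used:,} tokens -> {files}")
```

With the 16K budget only the first file fits, whereas the 128K budget holds all three, which is the situation the cross-file scenarios above depend on.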

## 🏁 Summary & Key Takeaways

### 📋 YARN Deep Dive Summary

1. **RoPE Foundation**: Rotary embeddings enable relative position encoding
2. **YARN Innovation**: Frequency-selective scaling for context extension
3. **YARN Config**: s=40, α=1, β=32 (same as DeepSeek-V2) for the 8x extension (16K→128K)
4. **Training Strategy**: Two-stage gradual extension with batch size adjustment
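5. **Performance**: the paper reports strong "Needle in a Haystack" results across context lengths up to 128K; the 95% line in this notebook is a simulated target, not a reported number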
6. **Applications**: Repository-level code understanding, cross-file analysis
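
The frequency-selective scaling in item 2 can be sketched in a few lines of NumPy. This is a simplified illustration rather than the DeepSeek implementation: the RoPE dimension (64), base (10000) and 16K original context are assumptions chosen for the example, and `yarn_inv_freq` and `yarn_attention_factor` are hypothetical helper names. The attention factor uses the value recommended in the YaRN paper, which individual models may tune.

```python
import math
import numpy as np


def yarn_inv_freq(dim: int = 64, base: float = 10000.0, orig_ctx: int = 16 * 1024,
                  scale: float = 40.0, alpha: float = 1.0, beta: float = 32.0) -> np.ndarray:
    """Blend interpolated and original RoPE frequencies per dimension."""
    j = np.arange(dim // 2)
    theta = base ** (-2.0 * j / dim)               # original RoPE frequencies
    rotations = orig_ctx * theta / (2 * math.pi)   # full turns within the original context
    # Ramp: 0 below alpha (interpolate fully), 1 above beta (keep unchanged)
    gamma = np.clip((rotations - alpha) / (beta - alpha), 0.0, 1.0)
    return (1.0 - gamma) * theta / scale + gamma * theta


def yarn_attention_factor(scale: float) -> float:
    """Temperature term recommended in the YaRN paper: 0.1 * ln(s) + 1."""
    return 0.1 * math.log(scale) + 1.0


theta = 10000.0 ** (-2.0 * np.arange(32) / 64)
ratio = yarn_inv_freq() / theta
print(f"highest-frequency dim keeps ratio {ratio[0]:.3f}")
print(f"lowest-frequency dim scaled to {ratio[-1]:.3f} (= 1/s)")
print(f"attention factor for s=40: {yarn_attention_factor(40.0):.3f}")
```

Dimensions that complete more than β rotations within the original context keep their frequency, those with fewer than α rotations are divided by s, and the rest are blended linearly between the two.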

### 🔬 Research Impact

YARN enables practical long-context applications in code intelligence:
- **Cross-file dependency analysis**
- **Large-scale refactoring**
- **Architecture-level understanding**
- **End-to-end bug tracing**
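
As a small illustration of cross-file dependency analysis, the sketch below builds a local import graph from a `{filename: source}` dictionary using Python's standard `ast` module. `local_import_graph` and the toy files are hypothetical and only show the kind of structure a long-context prompt could be organized around; it handles plain top-level module names and ignores relative imports.

```python
import ast
from typing import Dict, Set


def local_import_graph(codebase: Dict[str, str]) -> Dict[str, Set[str]]:
    """Map each file to the other files in the codebase that it imports."""
    modules = {name[:-3]: name for name in codebase if name.endswith(".py")}
    graph: Dict[str, Set[str]] = {name: set() for name in codebase}
    for name, source in codebase.items():
        for node in ast.walk(ast.parse(source)):
            if isinstance(node, ast.Import):
                imported = [alias.name for alias in node.names]
            elif isinstance(node, ast.ImportFrom) and node.module:
                imported = [node.module]
            else:
                continue
            for module in imported:
                root = module.split(".")[0]
                # Keep only imports that resolve to another file in the codebase
                if root in modules and modules[root] != name:
                    graph[name].add(modules[root])
    return graph


toy = {
    "main.py": "from data_processor import DataProcessor\nimport logging\n",
    "data_processor.py": "import utils\n",
    "utils.py": "import math\n",
}
graph = local_import_graph(toy)
for name, deps in graph.items():
    print(f"{name} -> {sorted(deps)}")
```

Standard-library imports such as `logging` and `math` are dropped because they do not match any file in the dictionary.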

In [None]:
# Final YARN implementation summary
def create_yarn_summary_dashboard():
    """Create comprehensive YARN summary dashboard"""
    
    fig, axes = plt.subplots(2, 3, figsize=(20, 12))
    
    # 1. Context extension progression
    ax1 = axes[0, 0]
    models = ['GPT-3', 'DeepSeek-Coder\n(RoPE)', 'DeepSeek-Coder-V2\n(YARN)']
    context_lengths = [2, 16, 128]  # K tokens
    
    bars = ax1.bar(models, context_lengths, color=['gray', 'blue', 'green'], alpha=0.7)
    ax1.set_ylabel('Context Length (K tokens)')
    ax1.set_title('Context Length Evolution')
    ax1.set_yscale('log')
    
    for bar, length in zip(bars, context_lengths):
        height = bar.get_height()
        ax1.text(bar.get_x() + bar.get_width()/2., height * 1.1,
                f'{length}K', ha='center', va='bottom', fontweight='bold')
    
    # 2. YARN frequency modification visualization
    ax2 = axes[0, 1]
    
    alpha, beta = 1.0, 32.0
    scale = 40.0
    # Illustrative assumptions (not from the paper): RoPE dim 64, base 10000, 16K original context
    rope_dim, rope_base, orig_ctx = 64, 10000.0, 16 * 1024
    
    # α and β are thresholds on how many full rotations each dimension
    # completes within the original context, not on a 0-1 dimension ratio
    dim_idx = np.arange(rope_dim // 2)
    theta = rope_base ** (-2.0 * dim_idx / rope_dim)
    rotations = orig_ctx * theta / (2 * np.pi)
    
    # Ramp: 0 below α (full interpolation by 1/s), 1 above β (unchanged), linear in between
    gamma = np.clip((rotations - alpha) / (beta - alpha), 0.0, 1.0)
    scaling_factors = (1.0 - gamma) / scale + gamma
    
    ax2.plot(rotations, scaling_factors, linewidth=3, color='red')
    ax2.axvline(alpha, color='green', linestyle='--', alpha=0.7, label=f'α={alpha}')
    ax2.axvline(beta, color='orange', linestyle='--', alpha=0.7, label=f'β={beta}')
    ax2.set_xscale('log')
    ax2.set_xlabel('Rotations over original context')
    ax2.set_ylabel('Scaling Factor')
    ax2.set_title('YARN Frequency Scaling (s=40)')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    # 3. Memory scaling comparison
    ax3 = axes[0, 2]
    
    seq_lengths = [1, 2, 4, 8, 16, 32, 64, 128]  # K tokens
    memory_quadratic = [s**2 for s in seq_lengths]  # O(n²) attention
    memory_linear = [s for s in seq_lengths]        # Hypothetical O(n)
    
    ax3.plot(seq_lengths, memory_quadratic, 'r-o', linewidth=2, label='Standard Attention O(n²)')
    ax3.plot(seq_lengths, memory_linear, 'g--s', linewidth=2, label='Linear Attention O(n)')
    
    ax3.set_xlabel('Sequence Length (K tokens)')
    ax3.set_ylabel('Relative Memory Usage')
    ax3.set_title('Memory Scaling Challenges')
    ax3.set_yscale('log')
    ax3.legend()
    ax3.grid(True, alpha=0.3)
    
    # 4. Training stages
    ax4 = axes[1, 0]
    
    stages = ['Pre-training\n(16K)', 'Stage 1\n(32K)', 'Stage 2\n(128K)']
    context_lens = [16, 32, 128]
    batch_sizes = [2048, 1152, 288]  # Adjusted for memory
    
    ax4_twin = ax4.twinx()
    
    bars = ax4.bar(stages, context_lens, alpha=0.7, color='skyblue', label='Context Length')
    line = ax4_twin.plot(stages, batch_sizes, 'ro-', linewidth=2, markersize=8, label='Batch Size')
    
    ax4.set_ylabel('Context Length (K)', color='blue')
    ax4_twin.set_ylabel('Batch Size', color='red')
    ax4.set_title('YARN Training Strategy')
    
    # 5. Performance on different tasks
    ax5 = axes[1, 1]
    
    tasks = ['Code\nGeneration', 'Long Context\nQA', 'Repository\nAnalysis', 'Bug\nTracing']
    # Illustrative placeholder values only; these are not measured results
    original_perf = [90, 60, 30, 25]  # Hypothetical 16K-context scores
    yarn_perf = [90, 95, 90, 85]      # Hypothetical 128K-context (YARN) scores
    
    x = np.arange(len(tasks))
    width = 0.35
    
    bars1 = ax5.bar(x - width/2, original_perf, width, label='16K Context', alpha=0.7, color='red')
    bars2 = ax5.bar(x + width/2, yarn_perf, width, label='128K Context (YARN)', alpha=0.7, color='green')
    
    ax5.set_ylabel('Performance (%)')
    ax5.set_title('Task Performance: 16K vs 128K Context (illustrative)')
    ax5.set_xticks(x)
    ax5.set_xticklabels(tasks)
    ax5.legend()
    ax5.set_ylim(0, 100)
    
    # 6. Key achievements
    ax6 = axes[1, 2]
    ax6.axis('off')
    
    achievements = [
        '🎯 8x Context Extension (16K → 128K)',
        '📊 Strong Needle-in-a-Haystack Results up to 128K',
        '🧵 YARN Frequency-Selective Scaling',
        '🏋️ Two-Stage Training Strategy',
        '💻 Repository-Level Understanding',
        '🔍 Cross-File Dependency Analysis',
        '⚡ Efficient Implementation',
        '🚀 State-of-the-Art Results'
    ]
    
    ax6.text(0.05, 0.95, 'YARN Key Achievements:', fontsize=14, fontweight='bold', 
             transform=ax6.transAxes)
    
    for i, achievement in enumerate(achievements):
        ax6.text(0.05, 0.85 - i*0.1, achievement, fontsize=11, 
                transform=ax6.transAxes)
    
    plt.suptitle('YARN Context Extension: Complete Technical Deep Dive', 
                 fontsize=16, fontweight='bold', y=0.98)
    plt.tight_layout()
    plt.show()
    
    # Technical specifications
    print("🧵 YARN Technical Specifications:")
    print("=" * 50)
    print(f"📊 Scale Factor (s): 40.0")
    print(f"📊 Alpha (α): 1.0 (rotation threshold below which frequencies are fully interpolated)")
    print(f"📊 Beta (β): 32.0 (rotation threshold above which frequencies are left unchanged)")
    print(f"📊 Context Extension: 16K → 128K (8x)")
    print(f"📊 Training: 2-stage (32K → 128K)")
    print(f"📊 Performance: strong Needle-in-a-Haystack results reported up to 128K")
    
    print("\n💡 Implementation Insights:")
    print("• Low frequencies get full scaling (preserve global structure)")
    print("• High frequencies unchanged (preserve fine details)")
    print("• Interpolation in middle frequencies (smooth transition)")
    print("• Attention scaling prevents distribution collapse")
    print("• Two-stage training enables gradual adaptation")

create_yarn_summary_dashboard()

print("\n🎉 YARN Context Extension Deep Dive Complete!")
print("\n📚 Further Reading:")
print("• YaRN: Efficient Context Window Extension of Large Language Models (Peng et al., 2023)")
print("• RoFormer: Enhanced Transformer with Rotary Position Embedding")
print("• DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model")
print("• Long Range Arena: A Benchmark for Efficient Transformers")
print("\n✨ Next: Explore Fill-In-the-Middle Training! ✨")