# 07 - Inference and Decoding Strategies

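This notebook implements and compares strategies for choosing the next token when generating text from a language model. A small random "mock model" stands in for a real network, so the outputs illustrate how each strategy works rather than real text quality.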

## Topics Covered:
- Greedy decoding
- Beam search
- Top-k sampling
- Top-p (nucleus) sampling
- Temperature scaling
- Repetition penalties
- Diverse beam search and contrastive search (advanced section)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Tuple, Optional
from collections import defaultdict
import heapq

np.random.seed(42)

## 1. Decoding Strategies Implementation

In [None]:
class DecodingStrategies:
    """Reference implementations of common next-token decoding strategies.

    Every method works on a 1-D array of unnormalised logits over the
    vocabulary; token ids are returned as plain Python ``int`` values.
    """

    def __init__(self, vocab_size: int = 1000):
        """Create a decoder.

        Args:
            vocab_size: Nominal vocabulary size. Kept as an attribute for
                reference; the methods take the actual size from the logits
                they are given.
        """
        self.vocab_size = vocab_size
        # Token id that marks end-of-sequence for beam search.
        self.eos_token = 0

    def softmax(self, logits: np.ndarray, temperature: float = 1.0) -> np.ndarray:
        """Return ``softmax(logits / temperature)`` as a probability vector.

        Args:
            logits: 1-D array of unnormalised scores.
            temperature: Strictly positive scale. Values below 1 sharpen the
                distribution; values above 1 flatten it.

        Raises:
            ValueError: If ``temperature`` is not strictly positive (it would
                otherwise produce inf/NaN probabilities).
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        scaled_logits = np.asarray(logits, dtype=float) / temperature
        # Subtract the max before exponentiating so large logits cannot overflow.
        exp_logits = np.exp(scaled_logits - np.max(scaled_logits))
        return exp_logits / np.sum(exp_logits)

    def greedy_decode(self, logits: np.ndarray) -> int:
        """Return the id of the highest-scoring token (first one on ties)."""
        return int(np.argmax(logits))

    def top_k_sampling(self, logits: np.ndarray, k: int, temperature: float = 1.0) -> int:
        """Sample a token from the ``k`` highest-scoring logits.

        Args:
            logits: 1-D array of unnormalised scores.
            k: Number of candidates to keep. Values larger than the
                vocabulary are clamped to the vocabulary size.
            temperature: Softmax temperature applied to the kept logits.

        Returns:
            The sampled token id.

        Raises:
            ValueError: If ``k < 1`` or ``temperature <= 0``.
        """
        logits = np.asarray(logits)
        if k < 1:
            # k == 0 used to select the whole array via the ``[-0:]`` slice.
            raise ValueError(f"k must be >= 1, got {k}")
        # argpartition raises when k exceeds the array length.
        k = min(k, logits.shape[-1])
        top_k_indices = np.argpartition(logits, -k)[-k:]
        probs = self.softmax(logits[top_k_indices], temperature)
        sampled_idx = np.random.choice(len(top_k_indices), p=probs)
        return int(top_k_indices[sampled_idx])

    def top_p_sampling(self, logits: np.ndarray, p: float, temperature: float = 1.0) -> int:
        """Top-p (nucleus) sampling.

        Keeps the smallest prefix of tokens, sorted by descending
        probability, whose cumulative probability reaches ``p``. The token
        that crosses ``p`` is included. The kept probabilities are
        renormalised and one token is sampled. ``p <= 0`` degenerates to
        greedy decoding.

        Raises:
            ValueError: If ``temperature <= 0``.
        """
        probs = self.softmax(logits, temperature)

        sorted_indices = np.argsort(probs)[::-1]
        sorted_probs = probs[sorted_indices]

        cumsum_probs = np.cumsum(sorted_probs)
        # First index whose cumulative mass is >= p, plus one to include it.
        # Clamped because rounding can leave cumsum[-1] slightly below p.
        nucleus_size = min(int(np.searchsorted(cumsum_probs, p)) + 1, len(sorted_probs))

        nucleus_indices = sorted_indices[:nucleus_size]
        nucleus_probs = sorted_probs[:nucleus_size]
        nucleus_probs = nucleus_probs / np.sum(nucleus_probs)

        sampled_idx = np.random.choice(len(nucleus_indices), p=nucleus_probs)
        return int(nucleus_indices[sampled_idx])

    def beam_search(self, get_logits_fn, start_token: int, max_length: int,
                    beam_width: int, length_penalty: float = 1.0) -> List[Tuple[List[int], float]]:
        """Length-normalised beam search.

        Args:
            get_logits_fn: Callable mapping the current token list to a 1-D
                logits array for the next token.
            start_token: First token of every hypothesis.
            max_length: Maximum number of expansion steps.
            beam_width: Number of live hypotheses kept per step. Each
                hypothesis expands at most ``beam_width`` tokens (clamped to
                the vocabulary size).
            length_penalty: Hypotheses are ranked by
                ``log_prob / len(sequence) ** length_penalty``.

        Returns:
            Finished hypotheses (those ending in ``eos_token``) plus the live
            ones left after ``max_length`` steps, as ``(sequence, log_prob)``
            pairs. They are sorted best-first by the length-normalised score.
            The returned ``log_prob`` is the raw cumulative log-probability.

        Raises:
            ValueError: If ``beam_width < 1``.
        """
        if beam_width < 1:
            raise ValueError(f"beam_width must be >= 1, got {beam_width}")

        def normalised_score(hypothesis):
            sequence, log_prob = hypothesis
            return log_prob / (len(sequence) ** length_penalty)

        beams = [([start_token], 0.0)]
        completed_beams = []

        for _ in range(max_length):
            candidates = []

            for sequence, log_prob in beams:
                if sequence[-1] == self.eos_token:
                    # Finished hypotheses leave the live set.
                    completed_beams.append((sequence, log_prob))
                    continue

                probs = self.softmax(get_logits_fn(sequence))
                # Clamp: argpartition raises if beam_width > vocabulary size.
                top_n = min(beam_width, len(probs))
                top_indices = np.argpartition(probs, -top_n)[-top_n:]

                for token_id in top_indices:
                    candidates.append((
                        sequence + [int(token_id)],
                        # Epsilon guards against log(0) on underflowed probs.
                        log_prob + np.log(probs[token_id] + 1e-10),
                    ))

            candidates.sort(key=normalised_score, reverse=True)
            beams = candidates[:beam_width]

            if not beams:
                break

        completed_beams.extend(beams)
        completed_beams.sort(key=normalised_score, reverse=True)
        return completed_beams

    def apply_repetition_penalty(self, logits: np.ndarray, generated_tokens: List[int],
                                 penalty: float = 1.2) -> np.ndarray:
        """Return a copy of ``logits`` with already-generated tokens penalised.

        Positive logits are divided by ``penalty`` and non-positive ones are
        multiplied by it, so a penalty > 1 always lowers the score. Token ids
        outside ``[0, len(logits))`` are ignored. The input is never modified.
        Integer logits are promoted to float so the division is valid.
        """
        penalized_logits = np.array(logits, copy=True)
        if not np.issubdtype(penalized_logits.dtype, np.floating):
            penalized_logits = penalized_logits.astype(float)

        vocab = len(penalized_logits)
        for token in set(generated_tokens):
            # Negative ids would otherwise wrap around and hit the last tokens.
            if 0 <= token < vocab:
                if penalized_logits[token] > 0:
                    penalized_logits[token] /= penalty
                else:
                    penalized_logits[token] *= penalty

        return penalized_logits

def demonstrate_decoding_strategies():
    """Demonstrate the DecodingStrategies methods on a toy random model.

    Prints the greedy / top-k / top-p choices for one sampled logits vector
    and the top beam-search hypotheses. Then draws a 3x3 matplotlib figure
    and prints a qualitative summary of each strategy. Results depend on the
    global NumPy RNG state, because every call to the mock model and every
    sampling call draws from ``np.random``.
    """
    
    # Create a simple mock language model
    vocab_size = 20
    decoder = DecodingStrategies(vocab_size)
    
    # Mock function to get logits (simplified)
    # Toy stand-in for a language model: fresh Gaussian noise on every call,
    # tokens 1-5 boosted, and the immediately previous token pushed down.
    # Only len(sequence) and sequence[-1] are used; the rest of the prefix is ignored.
    def get_logits(sequence):
        # Simple pattern: higher probability for tokens 1-5
        logits = np.random.randn(vocab_size) * 0.5
        logits[1:6] += 2.0  # Boost certain tokens
        
        # Add some context dependency
        if len(sequence) > 1:
            last_token = sequence[-1]
            if last_token < vocab_size:
                logits[last_token] -= 1.0  # Reduce repetition
        
        return logits
    
    # Generate sample logits
    sample_logits = get_logits([1])
    
    # Test different strategies
    print("Decoding Strategy Comparison:")
    print(f"Sample logits shape: {sample_logits.shape}")
    
    # Greedy decoding
    greedy_token = decoder.greedy_decode(sample_logits)
    print(f"\nGreedy decoding: token {greedy_token}")
    
    # Top-k sampling
    # Five independent draws from the same distribution to show sampling variance.
    top_k_tokens = [decoder.top_k_sampling(sample_logits, k=5) for _ in range(5)]
    print(f"Top-k sampling (k=5): {top_k_tokens}")
    
    # Top-p sampling
    top_p_tokens = [decoder.top_p_sampling(sample_logits, p=0.9) for _ in range(5)]
    print(f"Top-p sampling (p=0.9): {top_p_tokens}")
    
    # Beam search
    # get_logits draws new noise per call, so each prefix sees a different random distribution.
    beam_results = decoder.beam_search(get_logits, start_token=1, max_length=5, beam_width=3)
    print(f"\nBeam search results:")
    for i, (sequence, score) in enumerate(beam_results[:3]):
        print(f"  Beam {i+1}: {sequence} (score: {score:.3f})")
    
    # Visualize decoding strategies
    plt.figure(figsize=(15, 12))
    
    # Original probability distribution
    plt.subplot(3, 3, 1)
    probs = decoder.softmax(sample_logits)
    plt.bar(range(vocab_size), probs, alpha=0.7)
    plt.title('Original Probability Distribution')
    plt.xlabel('Token ID')
    plt.ylabel('Probability')
    
    # Temperature effects
    plt.subplot(3, 3, 2)
    temperatures = [0.5, 1.0, 2.0]
    for temp in temperatures:
        temp_probs = decoder.softmax(sample_logits, temp)
        plt.plot(temp_probs, label=f'T={temp}', alpha=0.7)
    
    plt.title('Temperature Effects')
    plt.xlabel('Token ID')
    plt.ylabel('Probability')
    plt.legend()
    
    # Top-k sampling visualization
    # For each k, show the renormalised distribution over the k kept tokens (others zero).
    plt.subplot(3, 3, 3)
    k_values = [3, 5, 10]
    
    for k in k_values:
        top_k_indices = np.argpartition(sample_logits, -k)[-k:]
        top_k_probs = np.zeros(vocab_size)
        top_k_probs[top_k_indices] = decoder.softmax(sample_logits[top_k_indices])
        
        plt.bar(range(vocab_size), top_k_probs, alpha=0.5, label=f'k={k}')
    
    plt.title('Top-k Sampling')
    plt.xlabel('Token ID')
    plt.ylabel('Probability')
    plt.legend()
    
    # Top-p sampling visualization
    # Repeats the nucleus selection from DecodingStrategies.top_p_sampling inline
    # so the kept (renormalised) mass can be plotted for each p.
    plt.subplot(3, 3, 4)
    p_values = [0.7, 0.9, 0.95]
    
    for p in p_values:
        probs = decoder.softmax(sample_logits)
        sorted_indices = np.argsort(probs)[::-1]
        sorted_probs = probs[sorted_indices]
        cumsum_probs = np.cumsum(sorted_probs)
        nucleus_size = np.searchsorted(cumsum_probs, p) + 1
        
        nucleus_probs = np.zeros(vocab_size)
        nucleus_indices = sorted_indices[:nucleus_size]
        nucleus_probs[nucleus_indices] = sorted_probs[:nucleus_size]
        nucleus_probs = nucleus_probs / np.sum(nucleus_probs)
        
        plt.bar(range(vocab_size), nucleus_probs, alpha=0.5, label=f'p={p}')
    
    plt.title('Top-p (Nucleus) Sampling')
    plt.xlabel('Token ID')
    plt.ylabel('Probability')
    plt.legend()
    
    # Repetition penalty effect
    plt.subplot(3, 3, 5)
    generated_tokens = [2, 3, 2, 4]  # Some repeated tokens
    
    original_probs = decoder.softmax(sample_logits)
    penalized_logits = decoder.apply_repetition_penalty(sample_logits, generated_tokens, penalty=1.5)
    penalized_probs = decoder.softmax(penalized_logits)
    
    x = np.arange(vocab_size)
    width = 0.35
    
    # Side-by-side bars: original (left) vs penalised (right) for every token.
    plt.bar(x - width/2, original_probs, width, label='Original', alpha=0.7)
    plt.bar(x + width/2, penalized_probs, width, label='With Penalty', alpha=0.7)
    
    # Highlight repeated tokens
    for token in set(generated_tokens):
        plt.axvline(x=token, color='red', linestyle='--', alpha=0.5)
    
    plt.title('Repetition Penalty Effect')
    plt.xlabel('Token ID')
    plt.ylabel('Probability')
    plt.legend()
    
    # Beam search tree visualization (simplified)
    plt.subplot(3, 3, 6)
    
    # Simulate beam search steps
    # NOTE: hard-coded schematic scores; not taken from the beam_search run above.
    beam_data = {
        'Step 0': [1.0],
        'Step 1': [0.8, 0.6, 0.4],
        'Step 2': [0.7, 0.5, 0.4, 0.3, 0.2],
        'Step 3': [0.6, 0.4, 0.3]
    }
    
    for i, (step, scores) in enumerate(beam_data.items()):
        y_positions = np.linspace(-1, 1, len(scores))
        # Marker area is proportional to the (illustrative) score.
        plt.scatter([i] * len(scores), y_positions, s=[s*100 for s in scores], alpha=0.7)
        
        # Connect to previous step (simplified)
        # Node j at this step is joined to node j at the previous step; purely schematic.
        if i > 0:
            prev_step, prev_scores = list(beam_data.items())[i-1]
            prev_y = np.linspace(-1, 1, len(prev_scores))
            for j, y in enumerate(y_positions[:len(prev_y)]):
                if j < len(prev_y):
                    plt.plot([i-1, i], [prev_y[j], y], 'k-', alpha=0.3)
    
    plt.title('Beam Search Tree')
    plt.xlabel('Step')
    plt.ylabel('Beam Position')
    plt.xticks(range(len(beam_data)), beam_data.keys())
    
    # Strategy comparison metrics
    plt.subplot(3, 3, 7)
    
    # Simulate diversity and quality metrics
    # NOTE: hand-picked illustrative values, not measured from the runs above.
    strategies = ['Greedy', 'Top-k', 'Top-p', 'Beam Search']
    diversity = [0.1, 0.7, 0.8, 0.4]  # Higher = more diverse
    quality = [0.9, 0.7, 0.6, 0.8]   # Higher = better quality
    
    x = np.arange(len(strategies))
    width = 0.35
    
    plt.bar(x - width/2, diversity, width, label='Diversity', alpha=0.7)
    plt.bar(x + width/2, quality, width, label='Quality', alpha=0.7)
    
    plt.title('Strategy Trade-offs')
    plt.xlabel('Strategy')
    plt.ylabel('Score')
    plt.xticks(x, strategies, rotation=45)
    plt.legend()
    
    # Computational cost comparison
    plt.subplot(3, 3, 8)
    
    # Relative computational costs
    # NOTE: illustrative constants, not timings.
    costs = [1, 2, 2.5, 8]  # Relative to greedy
    
    bars = plt.bar(strategies, costs, alpha=0.7)
    plt.title('Computational Cost')
    plt.xlabel('Strategy')
    plt.ylabel('Relative Cost')
    plt.xticks(rotation=45)
    
    # Add value labels
    for bar, cost in zip(bars, costs):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.1, 
                f'{cost}x', ha='center')
    
    # Parameter sensitivity analysis
    plt.subplot(3, 3, 9)
    
    # Show how different parameters affect output diversity
    param_ranges = {
        'Temperature': np.linspace(0.1, 2.0, 10),
        'Top-k': np.arange(1, 11),
        'Top-p': np.linspace(0.1, 1.0, 10)
    }
    
    # Simulate diversity scores for different parameter values
    # NOTE: closed-form illustrative curves; no decoding is run here.
    for param_name, param_values in param_ranges.items():
        if param_name == 'Temperature':
            diversity_scores = 1 - np.exp(-param_values)  # Increases with temperature
        elif param_name == 'Top-k':
            diversity_scores = np.log(param_values + 1) / np.log(11)  # Logarithmic increase
        else:  # Top-p
            diversity_scores = param_values  # Linear increase
        
        plt.plot(param_values, diversity_scores, 'o-', label=param_name, alpha=0.7)
    
    plt.title('Parameter Sensitivity')
    plt.xlabel('Parameter Value')
    plt.ylabel('Output Diversity')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print("\nDecoding Strategy Characteristics:")
    
    print("\nGreedy Decoding:")
    print("  + Fast and deterministic")
    print("  + High quality for well-trained models")
    print("  - No diversity, can get stuck in loops")
    
    print("\nBeam Search:")
    print("  + Finds high-probability sequences")
    print("  + Good for tasks requiring accuracy")
    print("  - Computationally expensive")
    print("  - Can produce generic outputs")
    
    print("\nTop-k Sampling:")
    print("  + Good balance of quality and diversity")
    print("  + Controllable via k parameter")
    print("  - Fixed vocabulary size regardless of distribution")
    
    print("\nTop-p Sampling:")
    print("  + Adaptive vocabulary size")
    print("  + Good for creative tasks")
    print("  + Handles varying confidence levels")
    print("  - Can be unstable with poor models")

demonstrate_decoding_strategies()

## 2. Advanced Decoding Techniques

In [None]:
class AdvancedDecoding:
    """Advanced decoding techniques: grouped diverse beam search and a
    simplified contrastive search.

    Token ids in returned sequences are plain Python ``int`` values.
    """

    def __init__(self):
        # Token id that terminates a sequence.
        self.eos_token = 0

    def diverse_beam_search(self, get_logits_fn, start_token: int, max_length: int,
                            beam_width: int, num_groups: int, diversity_penalty: float = 0.5):
        """Diverse (grouped) beam search for more varied outputs.

        The beam budget is split into ``num_groups`` groups of
        ``beam_width // num_groups`` beams each. Any remainder is dropped.
        Groups are decoded one after another. While group g chooses the
        token for position t, that token's logit is lowered by
        ``diversity_penalty`` once for every beam of an earlier group that
        has it at position t. This pushes later groups toward different
        continuations.

        Args:
            get_logits_fn: Callable mapping the current token list to a 1-D
                logits array. The returned array is copied, never modified.
            start_token: First token of every hypothesis.
            max_length: Maximum number of expansion steps per group.
            beam_width: Total beam budget across all groups.
            num_groups: Number of groups.
            diversity_penalty: Amount subtracted from a logit per earlier-group
                occurrence of that token at the same position.

        Returns:
            ``(sequence, log_prob)`` pairs from all groups, sorted by
            cumulative log-probability (descending, no length
            normalisation). The log-probs come from the penalised
            distributions. Sequences ending in ``eos_token`` are carried
            forward unchanged.

        Raises:
            ValueError: If ``num_groups < 1`` or ``beam_width < num_groups``
                (either would leave no beams to search with).
        """
        if num_groups < 1:
            raise ValueError(f"num_groups must be >= 1, got {num_groups}")
        group_size = beam_width // num_groups
        if group_size < 1:
            raise ValueError(
                f"beam_width ({beam_width}) must be >= num_groups ({num_groups})"
            )
        all_beams = []

        for _group in range(num_groups):
            beams = [([start_token], 0.0)]

            for _step in range(max_length):
                candidates = []

                for sequence, log_prob in beams:
                    if sequence[-1] == self.eos_token:
                        candidates.append((sequence, log_prob))
                        continue

                    # Copy so the penalties below never mutate caller-owned data.
                    logits = np.array(get_logits_fn(sequence), dtype=float)

                    # Index of the token about to be chosen. Earlier groups are
                    # compared at this same position (not the previous one).
                    position = len(sequence)
                    for other_group_beams in all_beams:
                        for other_seq, _ in other_group_beams:
                            if position < len(other_seq):
                                other_token = other_seq[position]
                                if 0 <= other_token < len(logits):
                                    logits[other_token] -= diversity_penalty

                    probs = self._softmax(logits)
                    # Clamp: argpartition raises if group_size > vocabulary size.
                    top_n = min(group_size, len(probs))
                    top_indices = np.argpartition(probs, -top_n)[-top_n:]

                    for token_id in top_indices:
                        candidates.append((
                            sequence + [int(token_id)],
                            log_prob + np.log(probs[token_id] + 1e-10),
                        ))

                candidates.sort(key=lambda x: x[1], reverse=True)
                beams = candidates[:group_size]

                if not beams:
                    break

            all_beams.append(beams)

        result = [beam for group_beams in all_beams for beam in group_beams]
        result.sort(key=lambda x: x[1], reverse=True)
        return result

    def contrastive_search(self, get_logits_fn, start_token: int, max_length: int,
                           alpha: float = 0.6, k: int = 4):
        """Simplified contrastive search.

        At each step the ``k`` most probable tokens (clamped to the
        vocabulary size) are scored as
        ``alpha * log p(token) - (1 - alpha) * penalty``. ``penalty`` is 0.5
        when the token occurs among the last 5 tokens of the sequence and the
        sequence has more than one token; otherwise it is 0. The
        best-scoring token is appended.

        Unlike the formulation in Su et al. (2022), where alpha weights the
        degeneration penalty, here ``alpha`` weights model confidence. A
        larger alpha therefore gives the repetition penalty less relative
        weight.

        Returns:
            The token list, starting with ``start_token``. It has at most
            ``max_length`` generated tokens and stops right after
            ``eos_token`` is appended.

        Raises:
            ValueError: If ``k < 1``.
        """
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k}")
        sequence = [start_token]

        for _ in range(max_length):
            if sequence[-1] == self.eos_token:
                break

            probs = self._softmax(get_logits_fn(sequence))
            top_n = min(k, len(probs))
            top_k_indices = np.argpartition(probs, -top_n)[-top_n:]

            # Repetition window; empty while only the start token is present.
            recent_tokens = sequence[-5:] if len(sequence) > 1 else []

            best_score = float('-inf')
            best_token = int(top_k_indices[0])

            for token_id in top_k_indices:
                degeneration_penalty = 0.5 if token_id in recent_tokens else 0
                # Epsilon avoids log(0) when a candidate's probability underflows.
                score = (alpha * np.log(probs[token_id] + 1e-10)
                         - (1 - alpha) * degeneration_penalty)

                if score > best_score:
                    best_score = score
                    best_token = int(token_id)

            sequence.append(best_token)

        return sequence

    def _softmax(self, logits: np.ndarray) -> np.ndarray:
        """Numerically stable softmax over a 1-D logits array."""
        exp_logits = np.exp(logits - np.max(logits))
        return exp_logits / np.sum(exp_logits)

def demonstrate_advanced_decoding():
    """Demonstrate AdvancedDecoding on a toy random model.

    Runs diverse beam search and contrastive search against a mock logits
    function and prints their outputs plus simple diversity metrics. Then
    draws a 2x3 matplotlib figure and prints qualitative notes. Results
    depend on the global NumPy RNG state, because the mock model draws from
    ``np.random.randn`` on every call.
    """
    
    vocab_size = 15
    decoder = AdvancedDecoding()
    
    # Mock logits function with some patterns
    # Fresh Gaussian noise per call. Tokens 1-3 are boosted when the prefix
    # length is even and tokens 5-7 when it is odd; the last 3 tokens are lowered.
    def get_logits(sequence):
        logits = np.random.randn(vocab_size) * 0.5
        
        # Create some patterns
        if len(sequence) % 2 == 0:
            logits[1:4] += 1.5  # Boost tokens 1-3 on even steps
        else:
            logits[5:8] += 1.5  # Boost tokens 5-7 on odd steps
        
        # Reduce probability of recent tokens
        for token in sequence[-3:]:
            if token < vocab_size:
                logits[token] -= 0.5
        
        return logits
    
    # Test advanced techniques
    print("Advanced Decoding Techniques:")
    
    # Diverse beam search
    # 6 beams split into 2 groups of 3.
    diverse_results = decoder.diverse_beam_search(
        get_logits, start_token=1, max_length=8, 
        beam_width=6, num_groups=2, diversity_penalty=0.5
    )
    
    print("\nDiverse Beam Search Results:")
    for i, (sequence, score) in enumerate(diverse_results[:4]):
        print(f"  {i+1}: {sequence} (score: {score:.3f})")
    
    # Contrastive search
    contrastive_result = decoder.contrastive_search(
        get_logits, start_token=1, max_length=8, alpha=0.6, k=4
    )
    
    print(f"\nContrastive Search Result: {contrastive_result}")
    
    # Compare diversity metrics
    def calculate_diversity(sequences):
        """Calculate diversity metrics for a set of sequences.

        Args:
            sequences: list of ``(token_list, score)`` pairs; scores are ignored.

        Returns:
            ``(token_diversity, sequence_diversity)``. ``token_diversity`` is
            the number of unique tokens divided by the total token count
            across all sequences. ``sequence_diversity`` is the mean pairwise
            Jaccard distance between the sequences' token sets. Returns
            ``(0, 0)`` for no sequences and ``(ratio, 0)`` for one.
        """
        if not sequences:
            return 0, 0
        
        # Unique tokens
        all_tokens = []
        for seq, _ in sequences:
            all_tokens.extend(seq)
        
        unique_tokens = len(set(all_tokens))
        total_tokens = len(all_tokens)
        
        # Sequence diversity (Jaccard distance)
        if len(sequences) < 2:
            return unique_tokens / total_tokens, 0
        
        diversity_sum = 0
        count = 0
        
        for i in range(len(sequences)):
            for j in range(i + 1, len(sequences)):
                seq1_set = set(sequences[i][0])
                seq2_set = set(sequences[j][0])
                
                intersection = len(seq1_set.intersection(seq2_set))
                union = len(seq1_set.union(seq2_set))
                
                if union > 0:
                    diversity_sum += 1 - (intersection / union)
                    count += 1
        
        avg_diversity = diversity_sum / count if count > 0 else 0
        token_diversity = unique_tokens / total_tokens
        
        return token_diversity, avg_diversity
    
    # Calculate metrics
    token_div, seq_div = calculate_diversity(diverse_results[:4])
    
    print(f"\nDiversity Metrics:")
    print(f"  Token diversity: {token_div:.3f}")
    print(f"  Sequence diversity: {seq_div:.3f}")
    
    # Visualize advanced techniques
    plt.figure(figsize=(15, 10))
    
    # Diverse beam search visualization
    plt.subplot(2, 3, 1)
    
    # Show token distribution across diverse beams
    token_counts = defaultdict(int)
    for sequence, _ in diverse_results[:4]:
        for token in sequence:
            token_counts[token] += 1
    
    tokens = list(token_counts.keys())
    counts = list(token_counts.values())
    
    plt.bar(tokens, counts, alpha=0.7)
    plt.title('Diverse Beam Search\nToken Distribution')
    plt.xlabel('Token ID')
    plt.ylabel('Frequency')
    
    # Contrastive search token distribution
    plt.subplot(2, 3, 2)
    
    contrastive_counts = defaultdict(int)
    for token in contrastive_result:
        contrastive_counts[token] += 1
    
    c_tokens = list(contrastive_counts.keys())
    c_counts = list(contrastive_counts.values())
    
    plt.bar(c_tokens, c_counts, alpha=0.7, color='orange')
    plt.title('Contrastive Search\nToken Distribution')
    plt.xlabel('Token ID')
    plt.ylabel('Frequency')
    
    # Diversity comparison
    plt.subplot(2, 3, 3)
    
    techniques = ['Standard\nBeam', 'Diverse\nBeam', 'Contrastive\nSearch']
    
    # Simulate diversity scores
    # NOTE: hand-picked illustrative values, not the metrics computed above.
    diversity_scores = [0.3, 0.7, 0.6]
    quality_scores = [0.8, 0.6, 0.7]
    
    x = np.arange(len(techniques))
    width = 0.35
    
    plt.bar(x - width/2, diversity_scores, width, label='Diversity', alpha=0.7)
    plt.bar(x + width/2, quality_scores, width, label='Quality', alpha=0.7)
    
    plt.title('Technique Comparison')
    plt.xlabel('Technique')
    plt.ylabel('Score')
    plt.xticks(x, techniques)
    plt.legend()
    
    # Sequence length analysis
    plt.subplot(2, 3, 4)
    
    diverse_lengths = [len(seq) for seq, _ in diverse_results[:6]]
    
    plt.hist(diverse_lengths, bins=5, alpha=0.7, label='Diverse Beam')
    plt.axvline(len(contrastive_result), color='orange', linestyle='--', 
               label='Contrastive', linewidth=2)
    
    plt.title('Sequence Length Distribution')
    plt.xlabel('Sequence Length')
    plt.ylabel('Frequency')
    plt.legend()
    
    # Parameter sensitivity for contrastive search
    plt.subplot(2, 3, 5)
    
    alpha_values = np.linspace(0.1, 0.9, 9)
    
    # Simulate how alpha affects diversity vs quality
    # NOTE(review): these are illustrative straight lines. In
    # AdvancedDecoding.contrastive_search, alpha multiplies log p(token) and
    # (1 - alpha) multiplies the repetition penalty, so a larger alpha weights
    # the penalty less. The trend drawn here may therefore be inverted for
    # that implementation; confirm before relying on it.
    diversity_trend = alpha_values  # Higher alpha = more diverse
    quality_trend = 1 - alpha_values  # Lower alpha = higher quality
    
    plt.plot(alpha_values, diversity_trend, 'o-', label='Diversity', alpha=0.7)
    plt.plot(alpha_values, quality_trend, 's-', label='Quality', alpha=0.7)
    
    plt.title('Contrastive Search\nAlpha Parameter Effect')
    plt.xlabel('Alpha Value')
    plt.ylabel('Score')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    # Computational complexity comparison
    plt.subplot(2, 3, 6)
    
    # NOTE: illustrative relative costs, not timings.
    complexities = {
        'Greedy': 1,
        'Beam Search': 5,
        'Diverse Beam': 7,
        'Contrastive': 3
    }
    
    plt.bar(complexities.keys(), complexities.values(), alpha=0.7)
    plt.title('Computational Complexity')
    plt.xlabel('Method')
    plt.ylabel('Relative Cost')
    plt.xticks(rotation=45)
    
    plt.tight_layout()
    plt.show()
    
    print("\nAdvanced Decoding Insights:")
    
    print("\nDiverse Beam Search:")
    print("  + Generates more varied outputs")
    print("  + Good for creative applications")
    print("  - More computationally expensive")
    print("  - May sacrifice some quality for diversity")
    
    print("\nContrastive Search:")
    print("  + Balances coherence and diversity")
    print("  + Reduces repetition effectively")
    print("  + Computationally efficient")
    print("  - Requires tuning of alpha parameter")
    
    print("\nKey Considerations:")
    print("  - Task requirements (creativity vs accuracy)")
    print("  - Computational budget")
    print("  - Model quality and training")
    print("  - User preferences and application context")

demonstrate_advanced_decoding()