# 06 - Training Large Language Models

This notebook covers practical aspects of training large language models.

## Topics Covered:
- Pretraining objectives
- Next-token prediction
- Training data collection and preprocessing
- Batch processing and token batching
- Mixed-precision training
- Distributed training strategies

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple, Optional
import re
from collections import defaultdict

np.random.seed(42)

## 1. Pretraining Objectives

In [None]:
class PretrainingObjectives:
    """Different pretraining objectives for language models.

    Both losses are mean negative log-likelihoods (natural log) of the target
    token under a softmax over the last axis of ``logits``.  They are computed
    from a log-softmax rather than ``log(softmax + eps)`` so that very
    unlikely targets are not silently capped at ``-log(eps)``.
    """

    @staticmethod
    def _target_log_probs(logits: np.ndarray, targets: np.ndarray) -> np.ndarray:
        """Return the log-probability assigned to each target token.

        Args:
            logits: Unnormalised scores; the last axis indexes the vocabulary.
            targets: Integer token ids, one per position, with shape
                ``logits.shape[:-1]``.

        Returns:
            Array with the same shape as ``targets``.

        Raises:
            ValueError: If ``targets`` does not match ``logits.shape[:-1]``.
        """
        logits = np.asarray(logits, dtype=np.float64)
        targets = np.asarray(targets)
        if targets.shape != logits.shape[:-1]:
            raise ValueError(
                f"targets shape {targets.shape} does not match "
                f"logits shape {logits.shape[:-1]} (logits without the vocab axis)"
            )

        # Subtracting the row maximum keeps exp() from overflowing; the
        # normaliser is then subtracted in log space (log-sum-exp).
        shifted = logits - np.max(logits, axis=-1, keepdims=True)
        log_norm = np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))
        log_probs = shifted - log_norm

        # Pick, for every position, the entry belonging to its target id.
        picked = np.take_along_axis(log_probs, targets[..., np.newaxis], axis=-1)
        return picked[..., 0]

    @staticmethod
    def next_token_prediction_loss(logits: np.ndarray, targets: np.ndarray) -> float:
        """Compute next-token prediction loss (cross-entropy).

        Args:
            logits: Scores with the vocabulary on the last axis.
            targets: Integer target ids with shape ``logits.shape[:-1]``.

        Returns:
            Mean cross-entropy over every position.

        Raises:
            ValueError: If the shapes disagree or there are no positions.
        """
        nll = -PretrainingObjectives._target_log_probs(logits, targets)
        if nll.size == 0:
            raise ValueError("targets must contain at least one position")
        return float(nll.mean())

    @staticmethod
    def masked_language_modeling_loss(logits: np.ndarray, targets: np.ndarray,
                                      mask: np.ndarray) -> float:
        """Compute MLM loss (only on masked positions).

        Args:
            logits: Scores with the vocabulary on the last axis.
            targets: Integer target ids with shape ``logits.shape[:-1]``.
            mask: Truthy where a position contributes to the loss; same shape
                as ``targets``.

        Returns:
            Mean cross-entropy over the masked positions, or ``0.0`` when no
            position is masked.

        Raises:
            ValueError: If the shapes of the inputs disagree.
        """
        nll = -PretrainingObjectives._target_log_probs(logits, targets)
        selected = np.asarray(mask, dtype=bool)
        if selected.shape != nll.shape:
            raise ValueError(
                f"mask shape {selected.shape} does not match targets shape {nll.shape}"
            )

        masked_count = int(selected.sum())
        if masked_count == 0:
            # Nothing was masked, so there is nothing to average over.
            return 0.0
        return float(nll[selected].sum() / masked_count)

def demonstrate_pretraining_objectives():
    """Evaluate both pretraining losses on random data and chart the result.

    Prints the two loss values, draws a two-panel figure (the MLM mask and a
    loss / token-coverage bar chart) and prints a short textual comparison.
    """

    # Toy problem dimensions
    n_batch, n_positions, n_vocab = 2, 8, 100

    # Draw logits, then targets, then the MLM mask; the draw order determines
    # which random numbers each array receives.
    logits = np.random.randn(n_batch, n_positions, n_vocab)
    targets = np.random.randint(0, n_vocab, (n_batch, n_positions))
    mask = np.random.random((n_batch, n_positions)) < 0.15  # ~15% masked

    # Evaluate each objective on the same logits/targets
    ntp_loss = PretrainingObjectives.next_token_prediction_loss(logits, targets)
    mlm_loss = PretrainingObjectives.masked_language_modeling_loss(logits, targets, mask)

    print("Pretraining Objectives Comparison:")
    print(f"Next-Token Prediction Loss: {ntp_loss:.4f}")
    print(f"Masked Language Modeling Loss: {mlm_loss:.4f}")

    plt.figure(figsize=(12, 6))

    # Left panel: which positions were selected for MLM
    plt.subplot(1, 2, 1)
    plt.imshow(mask, cmap='RdYlBu', aspect='auto')
    plt.title('MLM Masking Pattern')
    plt.xlabel('Sequence Position')
    plt.ylabel('Batch Item')
    plt.colorbar(label='Masked (1) / Not Masked (0)')

    # Right panel: loss and token coverage side by side on twin y-axes
    plt.subplot(1, 2, 2)

    objective_names = ['Next-Token\nPrediction', 'Masked LM']
    loss_values = [ntp_loss, mlm_loss]
    coverage_pct = [100, np.mean(mask) * 100]  # share of positions scored
    positions = np.arange(len(objective_names))
    bar_width = 0.35

    loss_axis = plt.gca()
    loss_axis.bar(positions - bar_width / 2, loss_values, bar_width,
                  label='Loss', alpha=0.7)
    loss_axis.set_ylabel('Loss Value')
    loss_axis.set_title('Objective Comparison')

    coverage_axis = loss_axis.twinx()
    coverage_axis.bar(positions + bar_width / 2, coverage_pct, bar_width,
                      label='Token Coverage %', alpha=0.7, color='orange')
    coverage_axis.set_ylabel('Token Coverage (%)')

    loss_axis.set_xticks(positions)
    loss_axis.set_xticklabels(objective_names)

    # One legend per axis, placed in opposite corners
    loss_axis.legend(loc='upper left')
    coverage_axis.legend(loc='upper right')

    plt.tight_layout()
    plt.show()

    # Closing textual summary, emitted line by line
    summary_lines = (
        "\nKey Differences:",
        "Next-Token Prediction (GPT-style):",
        "  - Uses all tokens for training",
        "  - Autoregressive generation",
        "  - Causal (left-to-right) attention",
        "\nMasked Language Modeling (BERT-style):",
        "  - Uses only masked tokens (~15%)",
        "  - Bidirectional context",
        "  - Better for understanding tasks",
    )
    for line in summary_lines:
        print(line)
demonstrate_pretraining_objectives()

## 2. Data Preprocessing Pipeline

In [None]:
class DataPreprocessor:
    """Data preprocessing pipeline for LLM training.

    The pipeline normalises text, keeps only text that looks English, applies
    simple quality heuristics and removes near-duplicates.  Every rejection is
    tallied in ``self.stats`` (a ``defaultdict(int)``) under a reason key.
    """

    def __init__(self):
        # Counters keyed by reason: 'filtered_language', 'filtered_length',
        # 'filtered_diversity', 'filtered_repetition', 'passed_quality',
        # 'duplicates_removed'.
        self.stats = defaultdict(int)

    def normalize_text(self, text: str) -> str:
        """Basic text normalization.

        Lower-cases, collapses whitespace runs to a single space, strips the
        ends, and maps typographic quotes and dashes to their ASCII forms.
        Double-quote variants become ``"`` and single-quote variants become
        ``'``, so apostrophes in contractions are preserved.
        """
        text = text.lower()

        # Normalize whitespace
        text = re.sub(r'\s+', ' ', text).strip()

        # Typographic double quotes -> straight double quote
        text = re.sub(r'[“”„‟]', '"', text)
        # Typographic single quotes / apostrophes -> straight apostrophe
        text = re.sub(r"[‘’‚‛]", "'", text)
        # En and em dashes -> hyphen
        text = re.sub(r'[–—]', '-', text)

        return text

    def detect_language(self, text: str) -> str:
        """Simple language detection (simplified).

        Returns:
            ``'unknown'`` when the text has no alphabetic characters,
            ``'en'`` when the letters e, t and a together make up more than
            25% of the alphabetic characters, and ``'other'`` otherwise.
        """
        lowered = text.lower()
        total_letters = sum(1 for ch in lowered if ch.isalpha())
        if total_letters == 0:
            return 'unknown'

        # Heuristic: e, t and a are counted as the English signal.
        english_letters = sum(lowered.count(ch) for ch in 'eta')
        if english_letters / total_letters > 0.25:
            return 'en'

        return 'other'

    def quality_filter(self, text: str) -> bool:
        """Filter low-quality text.

        Rejects text shorter than 50 or longer than 10000 characters, text
        with fewer than 10 distinct (case-folded) characters, and text where a
        single whitespace-separated word makes up more than 30% of all words.
        The matching counter in ``self.stats`` is incremented.

        Returns:
            ``True`` if the text passes every check, ``False`` otherwise.
        """
        # Length checks
        if len(text) < 50 or len(text) > 10000:
            self.stats['filtered_length'] += 1
            return False

        # Character diversity
        if len(set(text.lower())) < 10:
            self.stats['filtered_diversity'] += 1
            return False

        # Repetition check
        words = text.split()
        if words:
            word_counts = defaultdict(int)
            for word in words:
                word_counts[word] += 1

            if max(word_counts.values()) > len(words) * 0.3:
                self.stats['filtered_repetition'] += 1
                return False

        self.stats['passed_quality'] += 1
        return True

    def deduplicate(self, texts: List[str], threshold: float = 0.8) -> List[str]:
        """Remove near-duplicates using word-level Jaccard similarity.

        Texts are visited in order; a text is dropped when the Jaccard
        similarity between its set of lower-cased, whitespace-separated words
        and that of any previously kept text exceeds ``threshold``.  Each
        dropped text increments ``self.stats['duplicates_removed']``.

        Args:
            texts: Candidate texts.
            threshold: Similarity above which a text counts as a duplicate.

        Returns:
            The kept texts, in their original order.
        """
        def jaccard_similarity(tokens_a: set, tokens_b: set) -> float:
            union = len(tokens_a | tokens_b)
            return len(tokens_a & tokens_b) / union if union > 0 else 0

        kept_texts: List[str] = []
        # Token sets of the kept texts, cached so each is built only once.
        kept_tokens = []
        for text in texts:
            tokens = set(text.lower().split())
            if any(jaccard_similarity(tokens, seen) > threshold for seen in kept_tokens):
                self.stats['duplicates_removed'] += 1
                continue

            kept_texts.append(text)
            kept_tokens.append(tokens)

        return kept_texts

    def process_batch(self, texts: List[str]) -> List[str]:
        """Process a batch of texts.

        Each text is normalised, dropped unless detected as ``'en'`` (counted
        under ``'filtered_language'``), then passed through the quality
        filter.  The survivors are deduplicated.

        Returns:
            The normalised texts that passed every stage.
        """
        processed = []

        for text in texts:
            normalized = self.normalize_text(text)

            if self.detect_language(normalized) != 'en':
                self.stats['filtered_language'] += 1
                continue

            if not self.quality_filter(normalized):
                continue

            processed.append(normalized)

        return self.deduplicate(processed)

def demonstrate_data_preprocessing():
    """Run the preprocessing pipeline on sample texts and chart the outcome.

    Prints retention and filtering statistics, draws a 2x3 grid of summary
    plots, and prints the list of pipeline components.
    """

    # Hand-written inputs; trailing comments give the intended illustration
    raw_texts = [
        "This is a high-quality text with proper grammar and structure. It contains meaningful content.",
        "short text",  # very short
        "This is a high-quality text with proper grammar and structure. It contains meaningful content.",  # repeat of the first entry
        "aaaaa aaaaa aaaaa aaaaa aaaaa aaaaa aaaaa aaaaa",  # one word repeated
        "Another good quality text that provides valuable information and insights about various topics.",
        "MIXED    CASE   TEXT   WITH    WEIRD     SPACING!!!",
        "The quick brown fox jumps over the lazy dog. This sentence contains every letter of the alphabet.",
        "x" * 15000,  # very long
        "Bonjour, comment allez-vous? Je suis très bien.",  # French
        "Final text with good quality and reasonable length for training purposes."
    ]

    preprocessor = DataPreprocessor()
    processed_texts = preprocessor.process_batch(raw_texts)
    stats = preprocessor.stats  # same defaultdict object, shorter name
    n_raw = len(raw_texts)
    n_kept = len(processed_texts)

    # Headline numbers
    print("Data Preprocessing Results:")
    print(f"Input texts: {n_raw}")
    print(f"Output texts: {n_kept}")
    print(f"Retention rate: {n_kept/n_raw*100:.1f}%")

    print("\nFiltering Statistics:")
    for reason, count in stats.items():
        print(f"  {reason}: {count}")

    plt.figure(figsize=(15, 10))

    # Panel 1: number of texts remaining after each stage
    plt.subplot(2, 3, 1)
    stages = ['Raw\nTexts', 'Normalized', 'Language\nFiltered', 'Quality\nFiltered', 'Deduplicated']
    counts = [
        n_raw,
        n_raw,  # normalization never drops a text
        n_raw - stats['filtered_language'],
        stats['passed_quality'],
        n_kept
    ]

    plt.plot(stages, counts, 'o-', linewidth=2, markersize=8)
    plt.title('Data Processing Pipeline')
    plt.ylabel('Number of Texts')
    plt.xticks(rotation=45)
    plt.grid(True, alpha=0.3)

    # Panel 2: character-length histograms before and after
    plt.subplot(2, 3, 2)
    raw_lengths = list(map(len, raw_texts))
    processed_lengths = list(map(len, processed_texts))

    plt.hist(raw_lengths, bins=10, alpha=0.7, label='Raw', density=True)
    plt.hist(processed_lengths, bins=10, alpha=0.7, label='Processed', density=True)
    plt.title('Text Length Distribution')
    plt.xlabel('Text Length (characters)')
    plt.ylabel('Density')
    plt.legend()

    # Panel 3: share of each counter recorded in the stats
    plt.subplot(2, 3, 3)
    plt.pie(list(stats.values()), labels=list(stats.keys()), autopct='%1.1f%%')
    plt.title('Filtering Breakdown')

    # Panel 4: up to three processed samples rendered as text
    plt.subplot(2, 3, 4)
    plt.text(0.1, 0.9, "Sample Processed Texts:", fontsize=12, weight='bold')

    y_pos = 0.8
    for number, sample in enumerate(processed_texts[:3], start=1):
        if len(sample) > 50:
            shown = sample[:50] + "..."
        else:
            shown = sample
        plt.text(0.1, y_pos, f"{number}: {shown}", fontsize=9, wrap=True)
        y_pos -= 0.2

    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.axis('off')

    # Panel 5: distinct (case-folded) characters per processed text
    plt.subplot(2, 3, 5)
    diversities = [len(set(sample.lower())) for sample in processed_texts]
    plt.hist(diversities, bins=5, alpha=0.7)
    plt.title('Character Diversity')
    plt.xlabel('Unique Characters')
    plt.ylabel('Frequency')

    # Panel 6: three percentage metrics relative to the raw input count
    plt.subplot(2, 3, 6)
    efficiency_metrics = {
        'Retention Rate': n_kept/n_raw*100,
        'Quality Rate': stats['passed_quality']/n_raw*100,
        'Dedup Rate': (1 - stats['duplicates_removed']/n_raw)*100
    }

    plt.bar(list(efficiency_metrics.keys()), list(efficiency_metrics.values()), alpha=0.7)
    plt.title('Processing Efficiency')
    plt.ylabel('Percentage')
    plt.xticks(rotation=45)

    plt.tight_layout()
    plt.show()

    # Closing textual summary, emitted line by line
    component_lines = (
        "\nPreprocessing Pipeline Components:",
        "1. Text Normalization: Standardize format and encoding",
        "2. Language Detection: Filter non-target languages",
        "3. Quality Filtering: Remove low-quality content",
        "4. Deduplication: Remove duplicate or near-duplicate content",
        "5. Format Validation: Ensure proper structure",
    )
    for line in component_lines:
        print(line)

demonstrate_data_preprocessing()

## 3. Token Batching and Padding

In [None]:
class TokenBatcher:
    """Efficient token batching with padding strategies.

    Every strategy returns a list of ``(token_ids, attention_mask)`` pairs of
    2-D arrays, where the mask is 1 for positions copied from the input and 0
    for positions filled with ``pad_token_id``.
    """

    def __init__(self, pad_token_id: int = 0):
        self.pad_token_id = pad_token_id

    def _pad_to_length(self, sequences: List[List[int]],
                       length: int) -> Tuple[np.ndarray, np.ndarray]:
        """Truncate or right-pad every sequence to exactly ``length`` tokens."""
        padded_batch = []
        attention_masks = []

        for seq in sequences:
            kept = list(seq[:length])  # no-op slice when seq is short enough
            n_pad = length - len(kept)
            padded_batch.append(kept + [self.pad_token_id] * n_pad)
            attention_masks.append([1] * len(kept) + [0] * n_pad)

        return np.array(padded_batch), np.array(attention_masks)

    def static_batching(self, sequences: List[List[int]], batch_size: int,
                        max_length: int) -> List[Tuple[np.ndarray, np.ndarray]]:
        """Static batching with fixed sequence length.

        Consecutive groups of ``batch_size`` sequences are each truncated or
        padded to ``max_length``.  The last batch may hold fewer sequences.
        """
        return [
            self._pad_to_length(sequences[start:start + batch_size], max_length)
            for start in range(0, len(sequences), batch_size)
        ]

    def dynamic_batching(self, sequences: List[List[int]],
                         max_tokens: int) -> List[Tuple[np.ndarray, np.ndarray]]:
        """Dynamic batching based on token count.

        Sequences are sorted by length and added greedily to the current
        batch while ``longest_length * batch_size`` (the padded size) stays
        within ``max_tokens``.  A single sequence longer than ``max_tokens``
        still gets a batch of its own; it is not truncated.
        """
        batches = []
        current_batch: List[List[int]] = []

        for seq in sorted(sequences, key=len):
            # Ascending order means ``seq`` is the longest in the candidate
            # batch, so the padded size is simply its length times the count.
            tokens_needed = len(seq) * (len(current_batch) + 1)

            if tokens_needed <= max_tokens:
                current_batch.append(seq)
                continue

            # Budget exceeded: close the running batch and start a new one.
            if current_batch:
                batches.append(self._create_padded_batch(current_batch))
            current_batch = [seq]

        if current_batch:
            batches.append(self._create_padded_batch(current_batch))

        return batches

    def sequence_packing(self, sequences: List[List[int]],
                         max_length: int,
                         batch_size: int = 32) -> List[Tuple[np.ndarray, np.ndarray]]:
        """Pack multiple sequences into single examples.

        Sequences are concatenated in input order, with one ``pad_token_id``
        between neighbours as a separator, until the next sequence (plus its
        separator) would exceed ``max_length``.  The packed rows are then
        batched with :meth:`static_batching`, so separators are flagged as
        real tokens in the mask and a single sequence longer than
        ``max_length`` is truncated.

        Args:
            sequences: Token-id lists to pack.
            max_length: Length of every packed row.
            batch_size: Packed rows per returned batch (default 32).
        """
        packed_sequences = []
        current_packed: List[int] = []

        for seq in sequences:
            # The +1 reserves room for the separator token.
            if len(current_packed) + len(seq) + 1 <= max_length:
                if current_packed:
                    current_packed.append(self.pad_token_id)  # separator
                current_packed.extend(seq)
            else:
                if current_packed:
                    packed_sequences.append(current_packed)
                # Copy so the caller's list is never extended in place.
                current_packed = list(seq)

        if current_packed:
            packed_sequences.append(current_packed)

        return self.static_batching(packed_sequences, batch_size=batch_size,
                                    max_length=max_length)

    def _create_padded_batch(self, sequences: List[List[int]]) -> Tuple[np.ndarray, np.ndarray]:
        """Create padded batch from sequences, padded to the longest one."""
        max_len = max(len(seq) for seq in sequences)
        return self._pad_to_length(sequences, max_len)

def demonstrate_token_batching():
    """Compare static, dynamic and packed batching on random sequences.

    Prints per-strategy batch counts and padding efficiency, draws a 3x3 grid
    of comparison plots, and prints a list of trade-offs.
    """

    # Twenty random sequences; for each one the length is drawn first and the
    # token ids second, so the RNG stream is consumed in a fixed order.
    np.random.seed(42)
    sequences = [
        np.random.randint(1, 1000, np.random.randint(10, 100)).tolist()
        for _ in range(20)
    ]

    batcher = TokenBatcher(pad_token_id=0)

    # Run the three strategies on the same input
    static_batches = batcher.static_batching(sequences, batch_size=4, max_length=80)
    dynamic_batches = batcher.dynamic_batching(sequences, max_tokens=320)
    packed_batches = batcher.sequence_packing(sequences, max_length=150)

    def calculate_efficiency(batches):
        """Return (percent of mask entries set, set count, unset count)."""
        real_total = sum(np.sum(mask) for _, mask in batches)
        pad_total = sum(np.sum(1 - mask) for _, mask in batches)
        return real_total / (real_total + pad_total) * 100, real_total, pad_total

    static_eff, static_tokens, static_padding = calculate_efficiency(static_batches)
    dynamic_eff, dynamic_tokens, dynamic_padding = calculate_efficiency(dynamic_batches)
    packed_eff, packed_tokens, packed_padding = calculate_efficiency(packed_batches)

    print("Token Batching Comparison:")
    print(f"\nStatic Batching:")
    print(f"  Batches: {len(static_batches)}")
    print(f"  Efficiency: {static_eff:.1f}%")
    print(f"  Tokens: {static_tokens}, Padding: {static_padding}")

    print(f"\nDynamic Batching:")
    print(f"  Batches: {len(dynamic_batches)}")
    print(f"  Efficiency: {dynamic_eff:.1f}%")
    print(f"  Tokens: {dynamic_tokens}, Padding: {dynamic_padding}")

    print(f"\nSequence Packing:")
    print(f"  Batches: {len(packed_batches)}")
    print(f"  Efficiency: {packed_eff:.1f}%")
    print(f"  Tokens: {packed_tokens}, Padding: {packed_padding}")

    plt.figure(figsize=(15, 12))
    strategies = ['Static', 'Dynamic', 'Packed']

    # Panel 1: histogram of input sequence lengths
    plt.subplot(3, 3, 1)
    plt.hist(list(map(len, sequences)), bins=10, alpha=0.7)
    plt.title('Sequence Length Distribution')
    plt.xlabel('Sequence Length')
    plt.ylabel('Frequency')

    # Panel 2: efficiency per strategy, with the value written above each bar
    plt.subplot(3, 3, 2)
    efficiencies = [static_eff, dynamic_eff, packed_eff]

    bars = plt.bar(strategies, efficiencies, alpha=0.7)
    plt.title('Batching Efficiency')
    plt.ylabel('Efficiency (%)')

    for rect, value in zip(bars, efficiencies):
        label_x = rect.get_x() + rect.get_width()/2
        label_y = rect.get_height() + 1
        plt.text(label_x, label_y, f'{value:.1f}%', ha='center')

    # Panel 3: rows per batch for each strategy
    plt.subplot(3, 3, 3)
    batch_sizes = {
        'Static': [len(ids) for ids, _ in static_batches],
        'Dynamic': [len(ids) for ids, _ in dynamic_batches],
        'Packed': [len(ids) for ids, _ in packed_batches]
    }

    for strategy_name, row_counts in batch_sizes.items():
        plt.hist(row_counts, alpha=0.5, label=strategy_name, bins=5)

    plt.title('Batch Size Distribution')
    plt.xlabel('Batch Size')
    plt.ylabel('Frequency')
    plt.legend()

    # Panels 4-6: attention mask of the first batch of each strategy
    first_batches = [
        ('Static Batching', static_batches[0]),
        ('Dynamic Batching', dynamic_batches[0]),
        ('Packed Batching', packed_batches[0])
    ]

    for panel, (name, (_, mask)) in enumerate(first_batches, start=4):
        plt.subplot(3, 3, panel)
        plt.imshow(mask, cmap='RdYlBu', aspect='auto')
        plt.title(f'{name}\nSample Batch Mask')
        plt.xlabel('Sequence Position')
        plt.ylabel('Batch Item')
        plt.colorbar(label='Token (1) / Padding (0)')

    # Panel 7: stacked bars, padding drawn on top of real tokens
    plt.subplot(3, 3, 7)
    tokens_data = [static_tokens, dynamic_tokens, packed_tokens]
    padding_data = [static_padding, dynamic_padding, packed_padding]
    x = np.arange(len(strategies))
    width = 0.35

    plt.bar(x, tokens_data, width, label='Tokens', alpha=0.7)
    plt.bar(x, padding_data, width, bottom=tokens_data, label='Padding', alpha=0.7)

    plt.title('Token vs Padding')
    plt.xlabel('Strategy')
    plt.ylabel('Count')
    plt.xticks(x, strategies)
    plt.legend()

    # Panel 8: total number of array elements across all batches
    plt.subplot(3, 3, 8)
    memory_usage = {
        'Static': sum(ids.size for ids, _ in static_batches),
        'Dynamic': sum(ids.size for ids, _ in dynamic_batches),
        'Packed': sum(ids.size for ids, _ in packed_batches)
    }

    plt.bar(memory_usage.keys(), memory_usage.values(), alpha=0.7)
    plt.title('Memory Usage')
    plt.ylabel('Total Elements')
    plt.xticks(rotation=45)

    # Panel 9: how many batches each strategy produced
    plt.subplot(3, 3, 9)
    batch_counts = [len(static_batches), len(dynamic_batches), len(packed_batches)]

    plt.bar(strategies, batch_counts, alpha=0.7)
    plt.title('Number of Batches')
    plt.ylabel('Batch Count')

    plt.tight_layout()
    plt.show()

    # Closing textual summary, emitted line by line
    tradeoff_lines = (
        "\nBatching Strategy Trade-offs:",
        "\nStatic Batching:",
        "  + Simple implementation",
        "  + Predictable memory usage",
        "  - High padding overhead",
        "\nDynamic Batching:",
        "  + Better padding efficiency",
        "  + Adaptive batch sizes",
        "  - More complex implementation",
        "\nSequence Packing:",
        "  + Highest efficiency",
        "  + Minimal padding",
        "  - Requires careful attention masking",
        "  - More complex data loading",
    )
    for line in tradeoff_lines:
        print(line)

demonstrate_token_batching()