# Focused Learning: Training Stabilization Techniques (Z-loss, AdamW)

## Learning Objective
Take a close look at the training stabilization techniques described in Section V.C of the Kotlin ML Pack paper: Z-loss, a decreased AdamW epsilon, gradient clipping, and an extended warm-up.

## Paper Reference
- **Section**: V.C - Training setup
- **Key Techniques**: Z-loss, decreased AdamW epsilon, gradient clipping, extended warm-up
- **Goal**: Prevent training instabilities and logit divergence in large language models

## 1. Understanding Training Instabilities

### 1.1 The Problem: Logit Divergence

From the paper: "The instability may happen closer to the end of the training, when the logits become very negative."

This is a critical issue when training large language models, because:
- Output logits can drift to large magnitudes, positive or negative
- The softmax denominator then explodes or vanishes
- Gradients become unstable
- Model performance can degrade catastrophically
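
The failure mode is easy to reproduce in isolation. The sketch below uses plain Python so the arithmetic stays visible, and computes $\log \sum_j \exp(y_j)$ two ways for logits that have drifted far from zero. The helper names `log_z_naive` and `log_z_stable` and the example logits are illustrative assumptions, not taken from the paper.

```python
import math

def log_z_naive(logits):
    """log(sum(exp(y_j))) computed directly; breaks for extreme logits."""
    return math.log(sum(math.exp(y) for y in logits))

def log_z_stable(logits):
    """Same quantity via the max-shift trick that logsumexp relies on."""
    m = max(logits)
    return m + math.log(sum(math.exp(y - m) for y in logits))

small = [0.5, -1.0, 2.0]
large = [y * 500 for y in small]          # logits drifted to large positive values
very_negative = [-800.0, -805.0, -810.0]  # logits drifted to very negative values

print("small logits :", log_z_naive(small), log_z_stable(small))
print("large logits :", log_z_stable(large))          # dominated by the max logit
print("very negative:", log_z_stable(very_negative))

for name, logits in [("large", large), ("very_negative", very_negative)]:
    try:
        log_z_naive(logits)
    except (OverflowError, ValueError) as err:
        # exp() overflows for large logits; the sum underflows to 0 for very negative ones
        print(f"naive computation failed on {name}: {err!r}")
```

The stable version still returns a finite value, but that value is far from zero in both drifted cases, which is the quantity Z-loss penalizes.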

In [None]:
# Install required packages
!pip install torch numpy matplotlib seaborn pandas

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple, Optional
import pandas as pd
from dataclasses import dataclass

# Set up visualization
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (10, 6)

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

## 2. Z-Loss: The Core Innovation

### 2.1 Mathematical Foundation

The Z-loss is defined as:

$$\mathcal{L}_z = \log^2(Z)$$

where $Z = \sum_j \exp(y_j)$ is the partition function (softmax denominator).

The key insight: penalizing $\log^2(Z)$ pulls $\log(Z)$ toward 0 (that is, $Z$ toward 1), which keeps the logits from drifting to large magnitudes in either direction.
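
To see how this penalty acts on individual logits, note that $\partial \log Z / \partial y_j = \exp(y_j)/Z$, so the gradient of $\log^2(Z)$ with respect to $y_j$ is $2 \log(Z) \cdot \mathrm{softmax}(y)_j$. The following sketch checks that expression against finite differences in plain Python. The function names and the example logits are assumptions made for illustration.

```python
import math

def log_z(logits):
    m = max(logits)
    return m + math.log(sum(math.exp(y - m) for y in logits))

def z_loss(logits):
    return log_z(logits) ** 2

def z_loss_grad(logits):
    """Analytic gradient: d/dy_j log^2(Z) = 2 * log(Z) * softmax(y)_j."""
    lz = log_z(logits)
    return [2.0 * lz * math.exp(y - lz) for y in logits]

logits = [2.0, -1.0, 0.5, 3.0]
analytic = z_loss_grad(logits)

# Central finite differences as an independent check of the formula
h = 1e-6
numeric = []
for j in range(len(logits)):
    up = list(logits)
    down = list(logits)
    up[j] += h
    down[j] -= h
    numeric.append((z_loss(up) - z_loss(down)) / (2 * h))

for a, n in zip(analytic, numeric):
    print(f"analytic={a:.6f}  numeric={n:.6f}")
```

Because the softmax weights sum to 1, the gradient components sum to $2\log(Z)$: when $\log(Z) > 0$ every logit is pushed down, most strongly the largest ones, and when $\log(Z) < 0$ every logit is pushed up.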

In [None]:
class ZLoss(nn.Module):
    """Z-loss implementation from Section V.C of the paper.
    
    This loss prevents logit divergence by penalizing large partition functions.
    """
    
    def __init__(self):
        super().__init__()
        self.loss_history = []  # Track loss values for analysis
    
    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Calculate Z-loss for given logits.
        
        Args:
            logits: Model output logits [batch_size, seq_len, vocab_size]
        
        Returns:
            Z-loss value
        """
        # Calculate log(sum(exp(logits))) using logsumexp for numerical stability
        log_z = torch.logsumexp(logits, dim=-1)
        
        # Z-loss is log^2(Z)
        z_loss = log_z ** 2
        
        # Store for analysis
        self.loss_history.append(z_loss.mean().item())
        
        return z_loss.mean()
    
    def analyze_impact(self, logits: torch.Tensor) -> Dict[str, float]:
        """Analyze the impact of Z-loss on logit distribution"""
        with torch.no_grad():
            # Calculate various statistics
            log_z = torch.logsumexp(logits, dim=-1)
            max_logit = logits.max(dim=-1)[0]
            min_logit = logits.min(dim=-1)[0]
            logit_range = max_logit - min_logit
            
            return {
                'mean_log_z': log_z.mean().item(),
                'max_log_z': log_z.max().item(),
                'mean_max_logit': max_logit.mean().item(),
                'mean_logit_range': logit_range.mean().item(),
                'z_loss': (log_z ** 2).mean().item()
            }

# Demonstrate Z-loss behavior
z_loss_fn = ZLoss()

# Create example logits with different scales
batch_size, seq_len, vocab_size = 8, 128, 32000
scales = [0.1, 1.0, 10.0, 100.0]
results = []

for scale in scales:
    logits = torch.randn(batch_size, seq_len, vocab_size) * scale
    z_loss = z_loss_fn(logits)
    stats = z_loss_fn.analyze_impact(logits)
    stats['scale'] = scale
    results.append(stats)

# Visualize results
df = pd.DataFrame(results)
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Plot 1: Z-loss vs scale
axes[0, 0].plot(df['scale'], df['z_loss'], 'o-', color='red', markersize=8)
axes[0, 0].set_xlabel('Logit Scale')
axes[0, 0].set_ylabel('Z-loss')
axes[0, 0].set_title('Z-loss vs Logit Scale')
axes[0, 0].set_xscale('log')
axes[0, 0].set_yscale('log')
axes[0, 0].grid(True, alpha=0.3)

# Plot 2: Log(Z) vs scale
axes[0, 1].plot(df['scale'], df['mean_log_z'], 'o-', color='blue', markersize=8)
axes[0, 1].set_xlabel('Logit Scale')
axes[0, 1].set_ylabel('Mean log(Z)')
axes[0, 1].set_title('Partition Function Growth')
axes[0, 1].set_xscale('log')
axes[0, 1].grid(True, alpha=0.3)

# Plot 3: Max logit vs scale
axes[1, 0].plot(df['scale'], df['mean_max_logit'], 'o-', color='green', markersize=8)
axes[1, 0].set_xlabel('Logit Scale')
axes[1, 0].set_ylabel('Mean Max Logit')
axes[1, 0].set_title('Maximum Logit Growth')
axes[1, 0].set_xscale('log')
axes[1, 0].grid(True, alpha=0.3)

# Plot 4: Logit range vs scale
axes[1, 1].plot(df['scale'], df['mean_logit_range'], 'o-', color='purple', markersize=8)
axes[1, 1].set_xlabel('Logit Scale')
axes[1, 1].set_ylabel('Mean Logit Range')
axes[1, 1].set_title('Logit Range (max - min)')
axes[1, 1].set_xscale('log')
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Key Insight: once the largest logit dominates log(Z), Z-loss grows roughly quadratically with logit scale")

## 3. Modified AdamW: Decreased Epsilon

### 3.1 The Paper's Finding

From Section V.C: "We find that in our case, the setting of ε = 10^-16 slightly improves both train loss and downstream benchmark scores at no extra costs."

This is eight orders of magnitude below the PyTorch AdamW default of ε = 10^-8. Let's see why it can matter.
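
A back-of-the-envelope calculation shows where ε starts to matter. The sketch assumes a constant gradient `g`, so that after bias correction the first moment is `g` and the second moment is `g**2`; the normalized Adam step is then `g / (|g| + eps)`. The function name `adam_step_size` and the sample gradient magnitudes are hypothetical.

```python
def adam_step_size(grad, eps):
    """Normalized Adam step m_hat / (sqrt(v_hat) + eps) under the simplifying
    assumption of a constant gradient (m_hat = grad, v_hat = grad**2)."""
    return grad / (abs(grad) + eps)

for grad in (1e-3, 1e-8, 1e-12):
    for eps in (1e-8, 1e-16):
        step = adam_step_size(grad, eps)
        print(f"|g|={grad:.0e}  eps={eps:.0e}  step/lr={step:.4f}")
```

With ε = 10^-8, a gradient of magnitude 10^-8 already gets only half of the nominal step, and smaller gradients get almost none. With ε = 10^-16, all three cases move by nearly the full learning rate.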

In [None]:
class ModifiedAdamW:
    """Demonstrates the impact of different epsilon values in AdamW"""
    
    def __init__(self, params, lr=1e-4, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01):
        self.params = list(params)
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.weight_decay = weight_decay
        
        # Initialize state
        self.state = {}
        for p in self.params:
            self.state[p] = {
                'm': torch.zeros_like(p.data),  # First moment
                'v': torch.zeros_like(p.data),  # Second moment
                't': 0  # Time step
            }
    
    def step(self, closure=None):
        """Perform a single optimization step"""
        loss = None
        if closure is not None:
            loss = closure()
        
        for p in self.params:
            if p.grad is None:
                continue
            
            grad = p.grad.data
            state = self.state[p]
            
            # Update biased first moment estimate
            state['m'] = self.beta1 * state['m'] + (1 - self.beta1) * grad
            
            # Update biased second raw moment estimate
            state['v'] = self.beta2 * state['v'] + (1 - self.beta2) * grad**2
            
            # Update time step
            state['t'] += 1
            
            # Bias correction
            m_hat = state['m'] / (1 - self.beta1**state['t'])
            v_hat = state['v'] / (1 - self.beta2**state['t'])
            
            # AdamW update with weight decay
            p.data = p.data - self.lr * self.weight_decay * p.data
            
            # Adam update
            # This is where epsilon matters!
            update = self.lr * m_hat / (torch.sqrt(v_hat) + self.eps)
            p.data = p.data - update
        
        return loss
    
    def analyze_update_magnitude(self, grad, second_moment):
        """Analyze how epsilon affects update magnitude"""
        # Simulate the denominator calculation
        denominators = {}
        epsilons = [1e-8, 1e-12, 1e-16]  # Different epsilon values
        
        for eps in epsilons:
            denom = torch.sqrt(second_moment) + eps
            update_scale = 1.0 / denom
            denominators[eps] = {
                'min': update_scale.min().item(),
                'max': update_scale.max().item(),
                'mean': update_scale.mean().item()
            }
        
        return denominators

# Demonstrate the impact of epsilon
def demonstrate_epsilon_impact():
    """Show how different epsilon values affect gradient updates"""
    
    # Create example gradients and second moments
    size = (1000,)
    
    # Case 1: Very small second moments (common in later training)
    small_second_moments = torch.tensor(np.random.exponential(1e-10, size))
    grad = torch.randn(size) * 0.01
    
    # Analyze update scales
    epsilons = [1e-8, 1e-12, 1e-16]
    results = []
    
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    for i, eps in enumerate(epsilons):
        denom = torch.sqrt(small_second_moments) + eps
        update_scale = 1.0 / denom
        actual_update = grad * update_scale
        
        # Plot distribution
        ax = axes[i]
        ax.hist(np.log10(update_scale.numpy()), bins=50, alpha=0.7, color='blue')
        ax.set_xlabel('log10(Update Scale)')
        ax.set_ylabel('Frequency')
        ax.set_title(f'ε = {eps}')
        ax.grid(True, alpha=0.3)
        
        # Add statistics
        ax.text(0.05, 0.95, f'Max: {update_scale.max():.2e}\nMin: {update_scale.min():.2e}', 
                transform=ax.transAxes, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
        
        results.append({
            'epsilon': eps,
            'max_scale': update_scale.max().item(),
            'min_scale': update_scale.min().item(),
            'mean_scale': update_scale.mean().item(),
            'max_update': actual_update.abs().max().item()
        })
    
    plt.tight_layout()
    plt.show()
    
    # Summary table
    df = pd.DataFrame(results)
    print("\nImpact of Epsilon on Update Magnitudes:")
    print("=" * 60)
    print(df.to_string(index=False))
    
    print("\nKey Insight: epsilon only matters where sqrt(v) is comparable to or smaller than it.")
    print("There, a smaller epsilon yields larger updates, so parameters with tiny gradients keep moving.")

demonstrate_epsilon_impact()

## 4. Gradient Norm Clipping

### 4.1 The Balance

From the paper: "We choose gradient clipping so that very few gradients are clipped, avoiding the effective decrease of the learning rate caused by aggressive gradient clipping."
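
The "effective decrease of the learning rate" can be made concrete: norm clipping multiplies the gradient by `min(1, max_norm / grad_norm)`, so the average of that factor over training steps acts like a learning-rate multiplier. The sketch below computes it for a handful of made-up gradient norms. The helper names and the numbers are assumptions for illustration only.

```python
def clip_coefficient(grad_norm, max_norm):
    """Factor applied to the gradient by norm clipping (assumes grad_norm > 0)."""
    return min(1.0, max_norm / grad_norm)

def clipping_summary(grad_norms, max_norm):
    """Return (fraction of steps clipped, mean scaling factor)."""
    coefs = [clip_coefficient(n, max_norm) for n in grad_norms]
    clipped_fraction = sum(c < 1.0 for c in coefs) / len(coefs)
    mean_scale = sum(coefs) / len(coefs)
    return clipped_fraction, mean_scale

grad_norms = [0.4, 0.6, 0.8, 1.2, 5.0]  # made-up per-step gradient norms
for max_norm in (0.5, 1.0, 10.0):
    frac, scale = clipping_summary(grad_norms, max_norm)
    print(f"max_norm={max_norm}: clipped {frac:.0%} of steps, mean step scale {scale:.2f}")
```

A tight threshold clips most steps and shrinks the average step well below the nominal learning rate, while a loose threshold leaves ordinary steps untouched and only intervenes on spikes.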

In [None]:
class GradientClippingAnalyzer:
    """Analyzes the impact of gradient clipping on training"""
    
    def __init__(self, max_norm=1.0):
        self.max_norm = max_norm
        self.clip_history = []
    
    def clip_grad_norm_(self, parameters, max_norm=None):
        """Custom gradient clipping with detailed tracking"""
        if max_norm is None:
            max_norm = self.max_norm
        
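        # Materialize the iterable: model.parameters() returns a generator,
        # which would be exhausted before the clipping loop below
        parameters = list(parameters)
        
        # Calculate total gradient norm
        total_norm = 0.0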
        for p in parameters:
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** 0.5
        
        # Calculate clipping coefficient
        clip_coef = max_norm / (total_norm + 1e-6)
        clip_coef = min(clip_coef, 1.0)
        
        # Apply clipping
        if clip_coef < 1.0:
            for p in parameters:
                if p.grad is not None:
                    p.grad.data.mul_(clip_coef)
        
        # Track statistics
        self.clip_history.append({
            'total_norm': total_norm,
            'clip_coef': clip_coef,
            'clipped': clip_coef < 1.0
        })
        
        return total_norm
    
    def analyze_clipping_impact(self, gradient_norms: List[float], clip_values: List[float]):
        """Analyze how different clipping values affect gradients"""
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        
        # Convert to numpy for easier manipulation
        gradient_norms = np.array(gradient_norms)
        
        for i, clip_value in enumerate(clip_values):
            ax = axes[i // 2, i % 2]
            
            # Calculate clipped norms
            clipped_norms = np.minimum(gradient_norms, clip_value)
            clip_ratio = np.sum(gradient_norms > clip_value) / len(gradient_norms)
            
            # Plot histogram
            ax.hist(gradient_norms, bins=50, alpha=0.5, label='Original', color='blue')
            ax.hist(clipped_norms, bins=50, alpha=0.5, label='Clipped', color='red')
            ax.axvline(clip_value, color='black', linestyle='--', label='Clip threshold')
            
            ax.set_xlabel('Gradient Norm')
            ax.set_ylabel('Frequency')
            ax.set_title(f'Clip Value: {clip_value} (Clipped: {clip_ratio*100:.1f}%)')
            ax.legend()
            ax.set_xlim(0, max(gradient_norms) * 1.1)
            
        plt.tight_layout()
        plt.show()

# Simulate gradient norms during training
def simulate_training_gradients(num_steps=1000):
    """Simulate gradient norms during training"""
    # Early training: high variance, occasional spikes
    early_grads = np.random.lognormal(mean=0.0, sigma=1.0, size=num_steps//3)
    
    # Mid training: more stable
    mid_grads = np.random.lognormal(mean=-0.5, sigma=0.5, size=num_steps//3)
    
    # Late training: very stable, small gradients
    late_grads = np.random.lognormal(mean=-1.0, sigma=0.3, size=num_steps//3)
    
    # Add occasional spikes
    all_grads = np.concatenate([early_grads, mid_grads, late_grads])
    spike_indices = np.random.choice(len(all_grads), size=20, replace=False)
    all_grads[spike_indices] *= np.random.uniform(5, 20, size=20)
    
    return all_grads

# Analyze gradient clipping
gradient_norms = simulate_training_gradients()
analyzer = GradientClippingAnalyzer()

# Test different clipping values
clip_values = [0.5, 1.0, 5.0, 10.0]
analyzer.analyze_clipping_impact(gradient_norms, clip_values)

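# Heuristic inspired by the paper, which only says that "very few" gradients
# should be clipped; the 5% figure is this notebook's own assumption
print("\nHeuristic: choose a clipping threshold at which fewer than ~5% of gradient norms are clipped")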
percentiles = [90, 95, 99, 99.9]
print("\nGradient Norm Percentiles:")
for p in percentiles:
    value = np.percentile(gradient_norms, p)
    print(f"  {p}th percentile: {value:.2f}")
print("\nRecommendation: Set clip_value around 95th-99th percentile")

## 5. Extended Warm-up Period

### 5.1 The Paper's Finding

"Using warm-up period length of up to 10% of the train dataset allows training models at a higher learning rate without facing instabilities."

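Translating "10% of the train dataset" into a step count depends on dataset size and batch size. The sketch below shows one way to do that arithmetic and evaluates a linear ramp at the midpoint of warm-up. The dataset size, batch size, peak learning rate, and function names are all hypothetical; the paper excerpt quoted here does not specify them.

```python
import math

def warmup_steps_for(num_examples, batch_size, epochs=1, warmup_fraction=0.10):
    """Return (total optimizer steps, warm-up steps) for a given data budget."""
    steps_per_epoch = math.ceil(num_examples / batch_size)
    total_steps = steps_per_epoch * epochs
    return total_steps, int(total_steps * warmup_fraction)

def linear_warmup_lr(step, warmup_steps, max_lr):
    """Linear ramp from 0 to max_lr, then constant."""
    if warmup_steps == 0 or step >= warmup_steps:
        return max_lr
    return max_lr * step / warmup_steps

total, warmup = warmup_steps_for(num_examples=50_000, batch_size=64)
print(f"total steps: {total}, warm-up steps: {warmup}")
print(f"LR halfway through warm-up: {linear_warmup_lr(warmup // 2, warmup, 1e-4):.2e}")
print(f"LR after warm-up: {linear_warmup_lr(warmup, warmup, 1e-4):.2e}")
```
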
In [None]:
class WarmupScheduler:
    """Implements various warm-up strategies for learning rate scheduling"""
    
    def __init__(self, optimizer, warmup_steps, max_lr, schedule_type='linear'):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.max_lr = max_lr
        self.schedule_type = schedule_type
        self.current_step = 0
    
    def get_lr(self, step=None):
        """Calculate learning rate for given step"""
        if step is None:
            step = self.current_step
        
        if step >= self.warmup_steps:
            return self.max_lr
        
        warmup_ratio = step / self.warmup_steps
        
        if self.schedule_type == 'linear':
            return self.max_lr * warmup_ratio
        elif self.schedule_type == 'exponential':
            # Despite the label, this is a quadratic (power-law) ramp
            return self.max_lr * (warmup_ratio ** 2)
        elif self.schedule_type == 'cosine':
            return self.max_lr * 0.5 * (1 + np.cos(np.pi * (1 - warmup_ratio)))
        else:
            raise ValueError(f"Unknown schedule type: {self.schedule_type}")
    
    def step(self):
        """Update learning rate"""
        self.current_step += 1
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr

def visualize_warmup_impact():
    """Visualize the impact of different warm-up strategies"""
    total_steps = 10000
    warmup_percentages = [1, 5, 10, 20]  # Percentage of total steps
    max_lr = 1e-3
    
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    axes = axes.ravel()
    
    for i, warmup_pct in enumerate(warmup_percentages):
        ax = axes[i]
        warmup_steps = int(total_steps * warmup_pct / 100)
        
        # Create dummy optimizer
        param = torch.nn.Parameter(torch.randn(1))
        optimizer = torch.optim.Adam([param], lr=0)
        
        # Test different schedules
        schedules = ['linear', 'exponential', 'cosine']
        colors = ['blue', 'red', 'green']
        
        for schedule, color in zip(schedules, colors):
            scheduler = WarmupScheduler(optimizer, warmup_steps, max_lr, schedule)
            lrs = [scheduler.get_lr(step) for step in range(total_steps)]
            
            # Plot learning rate schedule
            ax.plot(lrs[:warmup_steps*2], label=schedule, color=color, linewidth=2)
        
        ax.axvline(warmup_steps, color='black', linestyle='--', alpha=0.5, label='End of warmup')
        ax.set_xlabel('Training Step')
        ax.set_ylabel('Learning Rate')
        ax.set_title(f'Warmup: {warmup_pct}% of training ({warmup_steps} steps)')
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_xlim(0, warmup_steps * 2)
    
    plt.tight_layout()
    plt.show()
    
    # Summarize the qualitative effect of warm-up on stability
    print("\nImpact of Warm-up on Training Stability:")
    print("=" * 50)
    print("Without warm-up: High initial LR can cause gradient explosion")
    print("With 10% warm-up: Gradual increase prevents instabilities")
    print("Paper finding: Enables higher max LR → faster convergence")

visualize_warmup_impact()

## 6. Complete Training Configuration

Let's implement the complete training setup from the paper, combining all techniques.
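
Before wiring everything together, it helps to see why the Z-loss term is added to cross-entropy rather than used on its own. Cross-entropy is unchanged when the same constant is added to every logit, so by itself it cannot stop the logits from drifting; the Z-loss term does respond to such a shift. The sketch below demonstrates this for a single token position in plain Python. The function name `combined_loss`, the example logits, and the weight of 0.01 (the value used in this notebook's configuration) are illustrative assumptions.

```python
import math

def combined_loss(logits, target, z_weight):
    """Cross-entropy plus weighted Z-loss for a single token position."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(y - m) for y in logits))
    ce = log_z - logits[target]  # equals -log(softmax(logits)[target])
    z = log_z ** 2
    return ce + z_weight * z, ce, z

logits = [2.0, 0.0, -1.0]
shifted = [y + 10.0 for y in logits]  # same softmax, larger log(Z)

total_a, ce_a, z_a = combined_loss(logits, target=0, z_weight=0.01)
total_b, ce_b, z_b = combined_loss(shifted, target=0, z_weight=0.01)

print(f"original: ce={ce_a:.4f}  z_loss={z_a:.2f}  total={total_a:.4f}")
print(f"shifted : ce={ce_b:.4f}  z_loss={z_b:.2f}  total={total_b:.4f}")
```

The two cross-entropy values agree, while the Z-loss term grows with the shift, so the combined objective prefers the version whose $\log(Z)$ is closer to zero.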

In [None]:
class StabilizedTrainer:
    """Complete implementation of training stabilization techniques from the paper"""
    
    def __init__(self, model, config):
        self.model = model
        self.config = config
        
        # Initialize components
        self.z_loss = ZLoss()
        
        # Optimizer with decreased epsilon
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config['learning_rate'],
            weight_decay=config['weight_decay'],
            eps=config['adam_epsilon']  # 1e-16 from paper
        )
        
        # Warmup scheduler
        self.warmup_steps = int(config['total_steps'] * config['warmup_ratio'])
        self.scheduler = WarmupScheduler(
            self.optimizer,
            self.warmup_steps,
            config['learning_rate'],
            'linear'
        )
        
        # Tracking
        self.training_history = {
            'loss': [],
            'z_loss': [],
            'grad_norm': [],
            'learning_rate': [],
            'clipped': []
        }
    
    def compute_loss(self, logits, labels):
        """Compute combined loss with Z-loss regularization"""
        # Cross-entropy loss
        ce_loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            labels.view(-1)
        )
        
        # Z-loss
        z_loss = self.z_loss(logits)
        
        # Combined loss
        total_loss = ce_loss + self.config['z_loss_weight'] * z_loss
        
        return total_loss, ce_loss.item(), z_loss.item()
    
    def training_step(self, batch):
        """Single training step with all stabilization techniques"""
        # Forward pass
        logits = self.model(batch['input_ids'])
        
        # Compute loss
        loss, ce_loss, z_loss = self.compute_loss(logits, batch['labels'])
        
        # Backward pass
        self.optimizer.zero_grad()
        loss.backward()
        
        # Gradient clipping
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.config['max_grad_norm']
        )
        
        # Optimizer step
        self.optimizer.step()
        
        # Learning rate scheduling
        current_lr = self.scheduler.step()
        
        # Track metrics
        self.training_history['loss'].append(ce_loss)
        self.training_history['z_loss'].append(z_loss)
        self.training_history['grad_norm'].append(grad_norm.item())
        self.training_history['learning_rate'].append(current_lr)
        # Store a plain Python bool rather than a tensor so pandas/NumPy can aggregate it
        self.training_history['clipped'].append(grad_norm.item() > self.config['max_grad_norm'])
        
        return {
            'loss': ce_loss,
            'z_loss': z_loss,
            'grad_norm': grad_norm.item(),
            'lr': current_lr
        }

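# Configuration loosely based on the paper. Only the epsilon and the 10% warm-up
# are quoted in the sections above; values marked "illustrative" are placeholders
# for this simulation, not settings confirmed by the paper.
training_config = {
    'learning_rate': 1e-4,  # illustrative
    'weight_decay': 0.1,  # "quite large" as per paper
    'adam_epsilon': 1e-16,  # Decreased from default
    'z_loss_weight': 0.01,  # illustrative; PaLM used 1e-4
    'max_grad_norm': 1.0,  # illustrative
    'warmup_ratio': 0.1,  # 10% of training
    'total_steps': 10000
}

print("Training Configuration (paper-inspired):")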
print("=" * 50)
for key, value in training_config.items():
    print(f"{key:<20}: {value}")

# Visualize the complete training dynamics
def simulate_training_run():
    """Simulate a training run with stabilization techniques"""
    # Create dummy model
    class DummyModel(nn.Module):
        def __init__(self, vocab_size=32000, hidden_size=768):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, hidden_size)
            self.output = nn.Linear(hidden_size, vocab_size)
        
        def forward(self, input_ids):
            embeds = self.embedding(input_ids)
            return self.output(embeds)
    
    model = DummyModel()
    trainer = StabilizedTrainer(model, training_config)
    
    # Simulate training steps
    for step in range(1000):
        # Create dummy batch
        batch = {
            'input_ids': torch.randint(0, 32000, (8, 128)),
            'labels': torch.randint(0, 32000, (8, 128))
        }
        
        metrics = trainer.training_step(batch)
    
    return trainer.training_history

# Run simulation
history = simulate_training_run()

# Visualize training dynamics
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Plot 1: Loss curves
axes[0, 0].plot(history['loss'], label='CE Loss', alpha=0.8)
axes[0, 0].plot(history['z_loss'], label='Z-Loss', alpha=0.8)
axes[0, 0].set_xlabel('Step')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Training Losses')
axes[0, 0].legend()
axes[0, 0].set_yscale('log')
axes[0, 0].grid(True, alpha=0.3)

# Plot 2: Gradient norms
axes[0, 1].plot(history['grad_norm'], alpha=0.8, color='green')
axes[0, 1].axhline(training_config['max_grad_norm'], color='red', linestyle='--', label='Clip threshold')
axes[0, 1].set_xlabel('Step')
axes[0, 1].set_ylabel('Gradient Norm')
axes[0, 1].set_title('Gradient Norms')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Plot 3: Learning rate schedule
axes[1, 0].plot(history['learning_rate'], color='purple')
axes[1, 0].axvline(training_config['total_steps'] * training_config['warmup_ratio'],
                   color='black', linestyle='--', label='End of warmup')
axes[1, 0].set_xlabel('Step')
axes[1, 0].set_ylabel('Learning Rate')
axes[1, 0].set_title('Learning Rate Schedule')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# Plot 4: Clipping frequency
window_size = 50
clipping_rate = pd.Series(history['clipped']).astype(float).rolling(window_size).mean() * 100
axes[1, 1].plot(clipping_rate, color='orange')
axes[1, 1].set_xlabel('Step')
axes[1, 1].set_ylabel('Clipping Rate (%)')
axes[1, 1].set_title(f'Gradient Clipping Rate ({window_size}-step window)')
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nTraining Statistics:")
print(f"Average gradient clipping rate: {np.mean(history['clipped']) * 100:.2f}%")
print(f"Max gradient norm: {max(history['grad_norm']):.2f}")
print(f"Final Z-loss: {history['z_loss'][-1]:.4f}")

## 7. Dynamic Beta (Failed Experiment)

The paper mentions trying dynamic β₂ but finding it doesn't help. Let's understand why.
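
A quick calculation shows how the schedule $\beta_2 = 1 - k^{-0.8}$ (the formula used in the analysis code for this section) compares with a static value of 0.999. The sketch computes the averaging window $1/(1-\beta_2)$ at a few steps and the step at which the dynamic schedule catches up with the static one. The function names are assumptions made for illustration.

```python
def dynamic_beta2(step):
    """Dynamic schedule: beta2 = 1 - step**(-0.8), for step >= 1."""
    return 1.0 - step ** (-0.8)

def effective_window(beta2):
    """Approximate number of recent steps the second-moment average covers."""
    return 1.0 / (1.0 - beta2)

for step in (10, 100, 1000, 10000):
    b2 = dynamic_beta2(step)
    print(f"step {step:>5}: beta2={b2:.5f}  window={effective_window(b2):.1f}")

# Solve 1 - k**(-0.8) = 0.999 for k
crossover_step = (1.0 - 0.999) ** (-1.0 / 0.8)
print(f"dynamic beta2 reaches 0.999 at step ~{crossover_step:.0f}")
```

Up to the crossover step, the dynamic schedule averages over a shorter window than the static setting, so any difference between the two is concentrated in the early part of a run.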

In [None]:
def analyze_dynamic_beta():
    """Analyze the dynamic beta approach mentioned in the paper"""
    
    # Formula from paper: β₂ = 1 - k^(-0.8) where k is step number
    steps = np.arange(1, 10000)
    dynamic_beta2 = 1 - steps**(-0.8)
    static_beta2 = 0.999  # Standard AdamW value
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    
    # Plot beta2 values
    ax1.plot(steps, dynamic_beta2, label='Dynamic β₂', color='blue')
    ax1.axhline(static_beta2, color='red', linestyle='--', label='Static β₂ = 0.999')
    ax1.set_xlabel('Training Step')
    ax1.set_ylabel('β₂ Value')
    ax1.set_title('Dynamic vs Static β₂')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    ax1.set_xscale('log')
    
    # Plot effective window size (1/(1-β₂))
    dynamic_window = 1 / (1 - dynamic_beta2)
    static_window = 1 / (1 - static_beta2)
    
    ax2.plot(steps, dynamic_window, label='Dynamic β₂', color='blue')
    ax2.axhline(static_window, color='red', linestyle='--', label='Static β₂')
    ax2.set_xlabel('Training Step')
    ax2.set_ylabel('Effective Window Size')
    ax2.set_title('Second Moment Estimation Window')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    ax2.set_xscale('log')
    ax2.set_yscale('log')
    
    plt.tight_layout()
    plt.show()
    
    print("Analysis of Dynamic β₂:")
    print("=" * 50)
    print("Early training (step 10): β₂ = {:.4f}, window = {:.0f} steps".format(
        1 - 10**(-0.8), 1 / (1 - (1 - 10**(-0.8)))
    ))
    print("Mid training (step 1000): β₂ = {:.4f}, window = {:.0f} steps".format(
        1 - 1000**(-0.8), 1 / (1 - (1 - 1000**(-0.8)))
    ))
    print("\nPaper's finding: No improvement over static β₂")
    print("A possible reason: Kotlin fine-tuning may not suffer from the rare-token issue that PaLM faced")

analyze_dynamic_beta()

## 8. Summary: Best Practices for Stable Training

Based on the paper's findings, here's a complete recipe for stable LLM fine-tuning:

In [None]:
# Create a summary visualization
fig, ax = plt.subplots(figsize=(12, 8))

# Define techniques and their impacts
techniques = [
    "Z-Loss (weight=0.01)",
    "AdamW ε=1e-16",
    "Weight Decay=0.1",
    "Gradient Clipping (conservative)",
    "10% Warmup Period",
    "BF16 Precision"
]

impacts = [
    "Prevents logit divergence",
    "Enables updates for small gradients",
    "Additional regularization",
    "Handles gradient spikes",
    "Allows higher learning rates",
    "Reduces memory, maintains stability"
]

benefits = [
    "Stable training to completion",
    "Better final performance",
    "Prevents overfitting",
    "Robustness to outliers",
    "Faster convergence",
    "Efficient GPU utilization"
]

# Create table
table_data = []
for tech, impact, benefit in zip(techniques, impacts, benefits):
    table_data.append([tech, impact, benefit])

# Hide axes
ax.axis('tight')
ax.axis('off')

# Create table
table = ax.table(cellText=table_data,
                colLabels=['Technique', 'Primary Impact', 'Key Benefit'],
                cellLoc='left',
                loc='center',
                colWidths=[0.35, 0.35, 0.3])

# Style the table
table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1, 2)

# Color header
for i in range(3):
    table[(0, i)].set_facecolor('#4CAF50')
    table[(0, i)].set_text_props(weight='bold', color='white')

# Alternate row colors
for i in range(1, len(techniques) + 1):
    for j in range(3):
        if i % 2 == 0:
            table[(i, j)].set_facecolor('#f0f0f0')

plt.title('Training Stabilization Techniques Summary', fontsize=16, fontweight='bold', pad=20)
plt.show()

print("\nKey Takeaways:")
print("=" * 50)
print("1. Z-loss is crucial for preventing catastrophic divergence")
print("2. Small changes (like ε=1e-16) can have significant impacts")
print("3. Conservative gradient clipping is better than aggressive")
print("4. Extended warmup enables more aggressive learning rates")
print("5. These techniques enabled successful fine-tuning on Kotlin datasets")
print("\nReported result of the fine-tuning: up to a 16-point improvement on the Kotlin HumanEval benchmark")