# 10 - Fine-Tuning Techniques

This notebook covers various fine-tuning approaches for language models.

## Topics Covered:
- Supervised fine-tuning
- Instruction tuning
- Parameter-efficient fine-tuning (LoRA, Adapters, Prefix tuning)
- Continual learning

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple

# Seed NumPy's global RNG once so every random draw below (weights, inputs) is reproducible.
np.random.seed(seed=42)

## 1. Parameter-Efficient Fine-Tuning

In [None]:
class LoRA:
    """Low-Rank Adaptation of a single square weight matrix.

    The effective weight is ``W0 + (alpha / rank) * B @ A``. ``W0`` is a
    frozen ``(d_model, d_model)`` matrix. Only the factors ``B``
    (``d_model x rank``) and ``A`` (``rank x d_model``) count as trainable.
    ``B`` starts at zero, so the adapted layer initially reproduces the
    frozen one exactly.
    """

    def __init__(self, d_model: int, rank: int, alpha: float = 1.0):
        self.d_model = d_model
        self.rank = rank
        self.alpha = alpha

        # Trainable low-rank factors; a zero B makes the update B @ A vanish at init.
        self.B = np.zeros((d_model, rank))
        self.A = np.random.randn(rank, d_model) * 0.01

        # Frozen base weight (a random stand-in for pretrained weights).
        self.W0 = np.random.randn(d_model, d_model) * 0.1

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Apply the adapted layer to row-vector inputs whose last axis is d_model.

        The low-rank path is evaluated as ``(x @ B) @ A`` so the full
        ``d_model x d_model`` product ``B @ A`` is never formed.
        """
        scale = self.alpha / self.rank
        frozen_path = x @ self.W0
        low_rank_path = (x @ self.B) @ self.A
        return frozen_path + scale * low_rank_path

    def get_trainable_params(self) -> int:
        """Count the entries of the trainable factors A and B."""
        return sum(factor.size for factor in (self.A, self.B))

    def get_total_params(self) -> int:
        """Count trainable entries plus the frozen base weight W0."""
        frozen = self.W0.size
        return frozen + self.get_trainable_params()

class Adapter:
    """Bottleneck adapter with a residual connection.

    Computes ``x + ReLU(x @ W_down + b_down) @ W_up + b_up``: the input is
    projected down to ``bottleneck_size``, passed through ReLU, projected
    back up to ``d_model``, and added to the original input.
    """

    def __init__(self, d_model: int, bottleneck_size: int):
        self.d_model = d_model
        self.bottleneck_size = bottleneck_size

        # Projection weights, drawn down-projection first, then up-projection.
        self.W_down = np.random.randn(d_model, bottleneck_size) * 0.1
        self.W_up = np.random.randn(bottleneck_size, d_model) * 0.1
        # Both biases start at zero.
        self.b_down, self.b_up = np.zeros(bottleneck_size), np.zeros(d_model)

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Run ``x`` (last axis d_model) through the adapter and add the residual."""
        squeezed = x @ self.W_down + self.b_down
        activated = np.maximum(0, squeezed)
        expanded = activated @ self.W_up + self.b_up
        return x + expanded

    def get_trainable_params(self) -> int:
        """Count every weight and bias entry in the adapter."""
        tensors = (self.W_down, self.b_down, self.W_up, self.b_up)
        return sum(t.size for t in tensors)

class PrefixTuning:
    """Prefix tuning: learnable key/value prefixes prepended at every layer.

    For each of ``num_layers`` layers the module stores ``prefix_length``
    key vectors and ``prefix_length`` value vectors of width ``d_model``.
    These are the only trainable parameters.
    """

    def __init__(self, d_model: int, prefix_length: int, num_layers: int):
        self.d_model = d_model
        self.prefix_length = prefix_length
        self.num_layers = num_layers

        # Learnable prefix parameters, shape (num_layers, prefix_length, d_model).
        self.prefix_keys = np.random.randn(num_layers, prefix_length, d_model) * 0.1
        self.prefix_values = np.random.randn(num_layers, prefix_length, d_model) * 0.1

    def get_prefix_kv(self, layer_idx: int) -> Tuple[np.ndarray, np.ndarray]:
        """Return the (keys, values) prefix for one layer, each (prefix_length, d_model)."""
        return self.prefix_keys[layer_idx], self.prefix_values[layer_idx]

    def forward_with_prefix(self, x: np.ndarray, layer_idx: int) -> np.ndarray:
        """Attend from ``x`` over ``[prefix; x]`` using one layer's prefix.

        This is simplified single-head attention without Q/K/V projections:
        - the queries are the raw inputs ``x``;
        - the keys are the layer's prefix keys followed by ``x``;
        - the values are the prefix values followed by ``x``.
        The learned prefix keys therefore steer the attention weights, and
        the prefix values contribute to the output.

        Args:
            x: Input of shape (batch_size, seq_len, d_model).
            layer_idx: Index of the layer whose prefix is used.

        Returns:
            Array of shape (batch_size, seq_len, d_model).

        Raises:
            ValueError: If the last axis of ``x`` is not ``self.d_model``.
        """
        batch_size, seq_len, d_model = x.shape
        if d_model != self.d_model:
            raise ValueError(
                f"expected input last dimension {self.d_model}, got {d_model}"
            )

        prefix_k, prefix_v = self.get_prefix_kv(layer_idx)

        # Share the same prefix across the batch: (P, D) -> (B, P, D) without copying.
        prefix_shape = (batch_size, self.prefix_length, d_model)
        prefix_k_batch = np.broadcast_to(prefix_k, prefix_shape)
        prefix_v_batch = np.broadcast_to(prefix_v, prefix_shape)

        # Prepend the prefix along the sequence axis: (B, P + S, D).
        extended_keys = np.concatenate([prefix_k_batch, x], axis=1)
        extended_values = np.concatenate([prefix_v_batch, x], axis=1)

        # Scaled dot-product scores: (B, S, D) @ (B, D, P + S) -> (B, S, P + S).
        scores = x @ extended_keys.transpose(0, 2, 1) / np.sqrt(d_model)
        # Numerically stable softmax over the key axis.
        scores = scores - scores.max(axis=-1, keepdims=True)
        attention_weights = np.exp(scores)
        attention_weights /= attention_weights.sum(axis=-1, keepdims=True)

        # (B, S, P + S) @ (B, P + S, D) -> (B, S, D)
        return attention_weights @ extended_values

    def get_trainable_params(self) -> int:
        """Count all prefix key and value entries across every layer."""
        return self.prefix_keys.size + self.prefix_values.size

def demonstrate_peft_methods():
    """Compare LoRA, adapters and prefix tuning against full fine-tuning.

    Builds one instance of each PEFT method for a toy ``d_model``-wide,
    ``num_layers``-deep model. It then:
    - runs a forward pass of each method on random input;
    - prints parameter counts;
    - shows a six-panel comparison figure;
    - prints a short list of pros and cons.

    LoRA and adapter modules are built for a single layer, so their counts
    are multiplied by ``num_layers``. That way every bar in the comparison
    covers the same model depth as the prefix-tuning and full fine-tuning
    counts.
    """
    d_model = 512
    num_layers = 12
    batch_size, seq_len = 4, 16

    # Sample input of shape (batch_size, seq_len, d_model).
    x = np.random.randn(batch_size, seq_len, d_model)

    # Initialize PEFT methods
    lora = LoRA(d_model=d_model, rank=16, alpha=16)
    adapter = Adapter(d_model=d_model, bottleneck_size=64)
    prefix_tuning = PrefixTuning(d_model=d_model, prefix_length=10, num_layers=num_layers)

    print("Parameter-Efficient Fine-Tuning Comparison:")

    # LoRA adapts one d_model x d_model matrix; tokens are flattened for forward().
    lora_output = lora.forward(x.reshape(-1, d_model)).reshape(batch_size, seq_len, d_model)
    lora_params = lora.get_trainable_params()
    lora_total = lora.get_total_params()

    print(f"\nLoRA (rank={lora.rank}):")
    print(f"  Output shape: {lora_output.shape}")
    print(f"  Trainable params per layer: {lora_params:,}")
    print(f"  Trainable params ({num_layers} layers): {lora_params * num_layers:,}")
    print(f"  Total params per layer: {lora_total:,}")
    print(f"  Efficiency: {lora_params/lora_total*100:.2f}% trainable")

    # Adapter (one module per layer)
    adapter_output = adapter.forward(x)
    adapter_params = adapter.get_trainable_params()

    print(f"\nAdapter (bottleneck={adapter.bottleneck_size}):")
    print(f"  Output shape: {adapter_output.shape}")
    print(f"  Added per layer: {adapter_params:,}")
    print(f"  Trainable params ({num_layers} layers): {adapter_params * num_layers:,}")

    # Prefix tuning: its parameter tensor already spans all layers.
    prefix_output = prefix_tuning.forward_with_prefix(x, layer_idx=0)
    prefix_params = prefix_tuning.get_trainable_params()

    print(f"\nPrefix Tuning (length={prefix_tuning.prefix_length}):")
    print(f"  Output shape: {prefix_output.shape}")
    print(f"  Trainable params: {prefix_params:,}")
    print(f"  Params per layer: {prefix_params // prefix_tuning.num_layers:,}")

    # Full fine-tuning: approximate each layer by four d_model x d_model
    # matrices (e.g. attention Q/K/V/O). FFN and embeddings are ignored.
    full_model_params = num_layers * 4 * d_model * d_model

    print("\nFull Fine-tuning:")
    print(f"  All params trainable: {full_model_params:,}")

    methods = ['Full FT', 'LoRA', 'Adapter', 'Prefix']
    trainable_params = [
        full_model_params,
        lora_params * num_layers,
        adapter_params * num_layers,
        prefix_params,
    ]

    _plot_peft_comparison(methods, trainable_params, full_model_params, d_model)
    _print_peft_characteristics()


def _plot_peft_comparison(methods, trainable_params, full_model_params, d_model):
    """Draw and show the six-panel PEFT comparison figure.

    Args:
        methods: Four method labels, ordered Full FT, LoRA, Adapter, Prefix.
        trainable_params: Whole-model trainable-parameter count per method.
        full_model_params: Trainable-parameter count of full fine-tuning.
        d_model: Model width used for the per-layer sweeps.
    """
    plt.figure(figsize=(15, 10))

    # Panel 1: trainable parameters (log scale), labelled with % of full FT.
    plt.subplot(2, 3, 1)
    plt.bar(methods, trainable_params, alpha=0.7)
    plt.ylabel('Trainable Parameters')
    plt.title('Parameter Efficiency')
    plt.yscale('log')
    for i, params in enumerate(trainable_params):
        efficiency = params / full_model_params * 100
        plt.text(i, params, f'{efficiency:.2f}%', ha='center', va='bottom')

    # Panel 2: LoRA per-layer cost = A (rank x d) + B (d x rank).
    plt.subplot(2, 3, 2)
    ranks = [4, 8, 16, 32, 64]
    lora_params_by_rank = [2 * rank * d_model for rank in ranks]
    plt.plot(ranks, lora_params_by_rank, 'o-', linewidth=2)
    plt.xlabel('LoRA Rank')
    plt.ylabel('Parameters per Layer')
    plt.title('LoRA Rank vs Parameters')
    plt.grid(True, alpha=0.3)

    # Panel 3: adapter per-layer cost = two projections + two bias vectors.
    plt.subplot(2, 3, 3)
    bottlenecks = [32, 64, 128, 256, 512]
    adapter_params_by_bottleneck = [
        2 * d_model * bottleneck + d_model + bottleneck for bottleneck in bottlenecks
    ]
    plt.plot(bottlenecks, adapter_params_by_bottleneck, 's-', linewidth=2, color='orange')
    plt.xlabel('Bottleneck Size')
    plt.ylabel('Parameters per Layer')
    plt.title('Adapter Bottleneck vs Parameters')
    plt.grid(True, alpha=0.3)

    # Panel 4: fp32 storage for the trainable weights only (4 bytes each).
    # Frozen weights, gradients, optimizer state and activations are excluded,
    # so this is a lower bound on real training memory.
    plt.subplot(2, 3, 4)
    memory_usage = [p * 4 / (1024**3) for p in trainable_params]  # GB
    plt.bar(methods, memory_usage, alpha=0.7, color='green')
    plt.ylabel('Memory Usage (GB)')
    plt.title('Trainable-Weight Memory (fp32)')
    plt.yscale('log')

    # Panel 5: illustrative performance vs. parameters saved.
    plt.subplot(2, 3, 5)
    performance_scores = [100, 95, 92, 88]  # simulated, not measured; Full FT = 100%
    # Percentage of parameters frozen, derived from the actual counts (Full FT = 0).
    params_saved = [100 - p / full_model_params * 100 for p in trainable_params]
    plt.scatter(params_saved, performance_scores, s=100, alpha=0.7)
    for method, saved, score in zip(methods, params_saved, performance_scores):
        plt.annotate(method, (saved, score), xytext=(5, 5), textcoords='offset points')
    plt.xlabel('Parameters Saved (%)')
    plt.ylabel('Performance Score')
    plt.title('Efficiency vs Performance')
    plt.grid(True, alpha=0.3)

    # Panel 6: illustrative relative training speeds (higher = faster).
    plt.subplot(2, 3, 6)
    training_speeds = [1.0, 3.5, 2.8, 3.2]
    bars = plt.bar(methods, training_speeds, alpha=0.7, color='purple')
    plt.ylabel('Relative Training Speed')
    plt.title('Training Speed Comparison')
    for bar, speed in zip(bars, training_speeds):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.05,
                 f'{speed:.1f}x', ha='center')

    plt.tight_layout()
    plt.show()


def _print_peft_characteristics():
    """Print a short pros/cons summary for each PEFT method."""
    print("\nPEFT Method Characteristics:")

    print("\nLoRA:")
    print("  + Very parameter efficient")
    print("  + No inference overhead")
    print("  + Easy to merge with base model")
    print("  - Limited to linear layers")

    print("\nAdapters:")
    print("  + Modular and swappable")
    print("  + Good performance retention")
    print("  - Adds inference latency")
    print("  - More parameters than LoRA")

    print("\nPrefix Tuning:")
    print("  + No architectural changes")
    print("  + Works with frozen models")
    print("  - Reduces effective context length")
    print("  - Task-specific optimization needed")

# Run the PEFT comparison: prints parameter counts and shows the comparison figure.
demonstrate_peft_methods()