# Focused Learning 01: Parameter-Efficient Fine-Tuning (PEFT) Methods

## 🎯 Objective
This notebook explains in depth the **Parameter-Efficient Fine-Tuning (PEFT)** methods used in LLaMA-Reviewer, in particular **LoRA** and **Prefix-tuning**.

## 📍 Paper Reference
- **Section III.C**: Zero-init Attention Prefix-tuning
- **Section III.D**: Low-Rank Adaptation (LoRA)
- **Figure 4**: Details of prefix-tuning on LLaMA
- **Figure 5**: Core component of Low-Rank Adaptation
- **Equations 1-6**: Mathematical formulations

## 🔍 Core Problem
Fine-tuning large language models such as LLaMA (6.7B+ parameters) requires:
- **Computational Resources**: a large amount of GPU memory
- **Storage Space**: each fine-tuned model needs ~13GB of storage
- **Training Time**: updating all parameters takes a long time

**PEFT Solution**: Fine-tune only a small number of parameters (<1%) while keeping the base model weights frozen.
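
The storage and parameter figures above can be checked with some quick arithmetic. The sketch below assumes 16-bit weights and a hypothetical LoRA setup (rank 8 on two square projections per block of a 32-layer model with hidden size 4096); these settings are illustrative assumptions, not values confirmed by this notebook.

```python
# Back-of-the-envelope storage arithmetic (all settings below are assumptions).
BYTES_PER_PARAM_FP16 = 2           # assumption: weights stored as 16-bit floats

base_params = 6.7e9                # model size quoted in the text
full_model_gb = base_params * BYTES_PER_PARAM_FP16 / 1e9

# Hypothetical LoRA configuration, for illustration only.
hidden_size = 4096                 # LLaMA-7B hidden dimension
num_layers = 32                    # LLaMA-7B transformer blocks
adapted_per_layer = 2              # assumption: two square projections adapted per block
rank = 8

# Each adapted d x k matrix adds r * (d + k) trainable parameters.
lora_params = num_layers * adapted_per_layer * rank * (hidden_size + hidden_size)
lora_mb = lora_params * BYTES_PER_PARAM_FP16 / 1e6
fraction = lora_params / base_params * 100

print(f"Full model:   {full_model_gb:.1f} GB")
print(f"LoRA adapter: {lora_params:,} params, {lora_mb:.1f} MB ({fraction:.3f}% of base)")
```

Under these assumptions the adapter is a few megabytes, against roughly 13 GB for a full copy of the model.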

## 🧮 Mathematical Foundations

### Core Hypothesis
PEFT (LoRA in particular) builds on the hypothesis that adaptation has a low **"intrinsic rank"**, meaning the change in weights can be represented by low-rank matrices.
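
As a rough illustration of what a low intrinsic rank means, the sketch below builds a synthetic weight update that is rank 4 plus small noise and measures how well truncated SVD recovers it. The matrix sizes, noise level, and the helper `rank_r_error` are made up for this example.

```python
import torch

torch.manual_seed(0)
d, k, true_rank = 64, 48, 4

# Synthetic weight update: exactly rank 4, plus a little noise.
delta_w = torch.randn(d, true_rank) @ torch.randn(true_rank, k)
noisy = delta_w + 0.01 * torch.randn(d, k)

# Truncated SVD gives the best rank-r approximation in Frobenius norm.
U, S, Vh = torch.linalg.svd(noisy, full_matrices=False)

def rank_r_error(r: int) -> float:
    """Relative Frobenius error of the best rank-r approximation."""
    approx = (U[:, :r] * S[:r]) @ Vh[:r, :]
    return (torch.linalg.norm(noisy - approx) / torch.linalg.norm(noisy)).item()

for r in (1, 2, 4, 8):
    print(f"rank {r}: relative error {rank_r_error(r):.4f}")
```

Once `r` reaches the true rank, the error drops to the noise floor, which is the situation LoRA hopes to find in real weight updates.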

In [None]:
# Setup environment
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Tuple, Optional
import math

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print("🔧 Environment setup complete!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 🎯 Method 1: Low-Rank Adaptation (LoRA)

### Mathematical Foundation (Equation 5-6 from paper)

LoRA approximates the weight update through a low-rank decomposition:

$$W_0 + \Delta W = W_0 + W_{down} W_{up}$$

Where:
- $W_0 \in \mathbb{R}^{d \times k}$: Pre-trained weight matrix (frozen)
- $W_{down} \in \mathbb{R}^{d \times r}$: Trainable down-projection
- $W_{up} \in \mathbb{R}^{r \times k}$: Trainable up-projection  
- $r \ll \min(d, k)$: Low rank

Output computation:
$$\bar{h} = W_0 x + \Delta W x = h + W_{down} W_{up} x$$
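
A minimal numerical check of the output computation above, using the same symbols and shapes ($W_0$ is $d \times k$, $W_{down}$ is $d \times r$, $W_{up}$ is $r \times k$). The concrete sizes and random values are assumptions chosen for illustration.

```python
import torch

torch.manual_seed(0)
d, k, r = 768, 768, 8
dtype = torch.float64  # double precision keeps the comparison free of rounding noise

W0 = torch.randn(d, k, dtype=dtype)             # frozen pre-trained weight
W_down = torch.randn(d, r, dtype=dtype) * 0.01  # trainable, d x r
W_up = torch.randn(r, k, dtype=dtype) * 0.01    # trainable, r x k
x = torch.randn(k, dtype=dtype)

# Two-path form: h + W_down (W_up x); the d x k update is never materialized.
h_bar_two_path = W0 @ x + W_down @ (W_up @ x)

# Merged form: fold the update into one matrix, e.g. for inference.
W_merged = W0 + W_down @ W_up
h_bar_merged = W_merged @ x

print(torch.allclose(h_bar_two_path, h_bar_merged))
print(f"full update: {d * k:,} params; low-rank update: {r * (d + k):,} params")
```

The two forms agree, so the low-rank path can be kept separate during training and merged into $W_0$ afterwards.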

In [None]:
class LoRALayer(nn.Module):
    """Implementation of LoRA layer based on paper equations"""
    
    def __init__(self, original_layer: nn.Linear, rank: int, alpha: float = 1.0):
        super().__init__()
        self.original_layer = original_layer
        self.rank = rank
        self.alpha = alpha
        
        # Freeze original weights (W_0)
        for param in self.original_layer.parameters():
            param.requires_grad = False
        
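        # LoRA factors (Equation 5). nn.Linear stores weights as (out, in), so
        # lora_A.weight is r x k (W_up in the equation, applied to x first) and
        # lora_B.weight is d x r (W_down, projecting back to the output size).
        self.lora_A = nn.Linear(original_layer.in_features, rank, bias=False)
        self.lora_B = nn.Linear(rank, original_layer.out_features, bias=False)
        
        # Random init for A and zeros for B, so ΔW = 0 at the start of training
        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B.weight)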
        
        # Scaling factor
        self.scaling = alpha / rank
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass implementing Equation 6"""
        # Original output: h = W_0 * x
        original_output = self.original_layer(x)
        
        # LoRA output: W_down * W_up * x
        lora_output = self.lora_B(self.lora_A(x)) * self.scaling
        
        # Combined output: h + ΔW * x (Equation 6)
        return original_output + lora_output
    
    def get_delta_weights(self) -> torch.Tensor:
        """Get the ΔW matrix for analysis"""
        return (self.lora_B.weight @ self.lora_A.weight) * self.scaling

print("✅ LoRA implementation complete!")

In [None]:
# Demonstrate LoRA with mock data
def demonstrate_lora():
    """Demonstrate LoRA layer behavior"""
    
    # Create original layer (simulating part of transformer)
    d, k = 768, 768  # Typical transformer dimensions
    original_layer = nn.Linear(d, k)
    
    # Create LoRA layers with different ranks
    ranks = [4, 8, 16, 32]
    lora_layers = {}
    
    for rank in ranks:
        lora_layers[rank] = LoRALayer(original_layer, rank=rank, alpha=16)
    
    # Test input
    batch_size, seq_len = 4, 128
    x = torch.randn(batch_size, seq_len, d)
    
    print("🔍 LoRA Analysis:")
    print(f"Original layer parameters: {sum(p.numel() for p in original_layer.parameters()):,}")
    
    results = {}
    for rank in ranks:
        layer = lora_layers[rank]
        
        # Count trainable parameters
        trainable_params = sum(p.numel() for p in layer.parameters() if p.requires_grad)
        total_params = sum(p.numel() for p in layer.parameters())
        
        # Forward pass
        output = layer(x)
        
        # Analysis: ΔW is exactly zero here because lora_B is zero-initialized
        delta_w = layer.get_delta_weights()
        
        results[rank] = {
            'trainable_params': trainable_params,
            'total_params': total_params,
            'percentage': trainable_params / total_params * 100,
            'delta_norm': torch.norm(delta_w).item(),
            'output_shape': output.shape
        }
        
        print(f"\nRank {rank}:")
        print(f"  Trainable: {trainable_params:,} ({results[rank]['percentage']:.3f}%)")
        print(f"  ΔW norm: {results[rank]['delta_norm']:.4f}")
    
    return results

# Run demonstration
lora_results = demonstrate_lora()

## 🎯 Method 2: Zero-init Attention Prefix-tuning

### Mathematical Foundation (Equations 1-4 from paper)

Prefix-tuning adds K learnable prefix tokens to the top L layers of the transformer.
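
For a sense of scale, the number of trainable values is roughly K × L × C for the prefixes plus a handful of gates. The values of K and L below are assumptions picked for illustration; C and the head count are those of LLaMA-7B, and the gate count follows the one-gate-per-head choice made in this notebook's implementation.

```python
K = 10            # prefix length (assumed)
L = 30            # number of top layers that receive a prefix (assumed)
C = 4096          # hidden size (LLaMA-7B)
num_heads = 32    # attention heads per layer (LLaMA-7B)

prefix_params = L * K * C          # one K x C prefix matrix per adapted layer
gate_params = L * num_heads        # one gate per head per adapted layer
total = prefix_params + gate_params

print(f"prefix tokens: {prefix_params:,}, gates: {gate_params:,}, total: {total:,}")
```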

#### Attention Computation (Equation 1):
$$Q_l = Linear_q(t_{M+1})$$
$$K_l = Linear_k([P_l; T_l; t_{M+1}])$$
$$V_l = Linear_v([P_l; T_l; t_{M+1}])$$

Where $P_l \in \mathbb{R}^{K \times C}$ are the prefix tokens and $T_l \in \mathbb{R}^{M \times C}$ are the original tokens.

#### Attention Scores (Equation 2):
$$S_l^K = Q_l (K_l^K)^T / \sqrt{C}$$
$$S_l^{M+1} = Q_l (K_l^{M+1})^T / \sqrt{C}$$

#### Zero-init Gating (Equation 3):
$$S_l^g = [Softmax(S_l^K) \cdot g_l; Softmax(S_l^{M+1})]$$

Where $g_l$ is a learnable gating factor, initialized to 0.
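
The effect of the zero initialization can be seen in a small single-head sketch: with $g_l = 0$ the gated attention output matches ordinary attention over the original tokens, so training starts from the pre-trained behavior. The tensor sizes and the helper `gated_attention` are assumptions for illustration only.

```python
import math
import torch

torch.manual_seed(0)
C, K, M = 16, 4, 6   # feature dim, prefix length, number of context tokens

q = torch.randn(1, C)                                   # query for token t_{M+1}
k_prefix, v_prefix = torch.randn(K, C), torch.randn(K, C)
k_tokens, v_tokens = torch.randn(M + 1, C), torch.randn(M + 1, C)

def gated_attention(g: float) -> torch.Tensor:
    s_prefix = q @ k_prefix.T / math.sqrt(C)            # S_l^K
    s_tokens = q @ k_tokens.T / math.sqrt(C)            # S_l^{M+1}
    # Separate softmaxes; only the prefix part is scaled by the gate g_l.
    weights = torch.cat([torch.softmax(s_prefix, dim=-1) * g,
                         torch.softmax(s_tokens, dim=-1)], dim=-1)
    return weights @ torch.cat([v_prefix, v_tokens], dim=0)

# Ordinary attention over the original tokens only, with no prefix at all.
plain = torch.softmax(q @ k_tokens.T / math.sqrt(C), dim=-1) @ v_tokens

print(torch.allclose(gated_attention(0.0), plain, atol=1e-6))  # gate closed
print(torch.norm(gated_attention(0.5) - plain).item())         # gate partly open
```

As the gate moves away from zero during training, the prefix values are mixed in gradually rather than disturbing the output from the first step.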

In [None]:
class ZeroInitPrefixTuning(nn.Module):
    """Implementation of Zero-init Attention Prefix-tuning"""
    
    def __init__(self, 
                 hidden_size: int,
                 num_heads: int,
                 prefix_length: int,
                 num_layers: int):
        super().__init__()
        
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.prefix_length = prefix_length  # K in paper
        self.num_layers = num_layers  # L in paper
        self.head_dim = hidden_size // num_heads
        
        # Prefix tokens for each layer (P_l in Equation 1)
        self.prefix_tokens = nn.ParameterList([
            nn.Parameter(torch.randn(prefix_length, hidden_size))
            for _ in range(num_layers)
        ])
        
        # Zero-init gating factors (g_l in Equation 3)
        self.gating_factors = nn.ParameterList([
            nn.Parameter(torch.zeros(num_heads))  # Zero initialization
            for _ in range(num_layers)
        ])
        
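        # Linear projections: stand-ins for the pre-trained attention weights,
        # frozen below so that only the prefixes and gates are trainable
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)
        for proj in (self.q_proj, self.k_proj, self.v_proj, self.out_proj):
            for param in proj.parameters():
                param.requires_grad = False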
    
    def forward(self, x: torch.Tensor, layer_idx: int) -> torch.Tensor:
        """Forward pass implementing Equations 1-4"""
        batch_size, seq_len, hidden_size = x.shape
        
        # Get prefix tokens for this layer
        prefix = self.prefix_tokens[layer_idx].unsqueeze(0).expand(batch_size, -1, -1)
        
        # Concatenate prefix with original tokens (Equation 1)
        # [P_l; T_l; t_{M+1}] -> [prefix_len + seq_len, hidden_size]
        extended_input = torch.cat([prefix, x], dim=1)
        
        # Apply projections (Equation 1)
        Q = self.q_proj(x)  # Only query from original tokens
        K = self.k_proj(extended_input)  # Keys from prefix + original
        V = self.v_proj(extended_input)  # Values from prefix + original
        
        # Reshape for multi-head attention
        Q = Q.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        
        # Compute attention scores (Equation 2)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        
        # Split scores for prefix and original tokens
        prefix_scores = scores[:, :, :, :self.prefix_length]  # S_l^K
        original_scores = scores[:, :, :, self.prefix_length:]  # S_l^{M+1}
        
        # Apply zero-init gating (Equation 3)
        gating = self.gating_factors[layer_idx].view(1, self.num_heads, 1, 1)
        
        prefix_attn = torch.softmax(prefix_scores, dim=-1) * gating
        original_attn = torch.softmax(original_scores, dim=-1)
        
        # Combine attention weights
        combined_attn = torch.cat([prefix_attn, original_attn], dim=-1)
        
        # Apply attention to values (Equation 4)
        output = torch.matmul(combined_attn, V)
        
        # Reshape and project
        output = output.transpose(1, 2).reshape(batch_size, seq_len, hidden_size)
        output = self.out_proj(output)
        
        return output
    
    def get_trainable_params(self) -> int:
        """Count trainable parameters"""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

print("✅ Zero-init Prefix-tuning implementation complete!")

In [None]:
# Demonstrate Prefix-tuning
def demonstrate_prefix_tuning():
    """Demonstrate prefix-tuning behavior"""
    
    # Toy configuration (far smaller than LLaMA, for a quick demo)
    hidden_size = 768
    num_heads = 12
    num_layers = 12
    prefix_lengths = [5, 10, 20]  # Different prefix lengths to test
    
    # Test input
    batch_size, seq_len = 4, 64
    x = torch.randn(batch_size, seq_len, hidden_size)
    
    print("🔍 Prefix-tuning Analysis:")
    
    results = {}
    for prefix_len in prefix_lengths:
        model = ZeroInitPrefixTuning(
            hidden_size=hidden_size,
            num_heads=num_heads,
            prefix_length=prefix_len,
            num_layers=num_layers
        )
        
        # Forward pass through first layer
        output = model(x, layer_idx=0)
        
        # Analysis
        trainable_params = model.get_trainable_params()
        
        # Check gating factor (exactly 0 at initialization)
        initial_gating = model.gating_factors[0].mean().item()
        
        results[prefix_len] = {
            'trainable_params': trainable_params,
            'initial_gating': initial_gating,
            'output_shape': output.shape,
            # Norm of (output - input). This is not prefix-specific: with g_l = 0
            # the prefix contributes nothing yet, so the value reflects only the
            # randomly initialized attention projections.
            'prefix_contribution': torch.norm(output - x).item()
        }
        
        print(f"\nPrefix Length {prefix_len}:")
        print(f"  Trainable params: {trainable_params:,}")
        print(f"  Initial gating: {initial_gating:.6f}")
        print(f"  ||output - input||: {results[prefix_len]['prefix_contribution']:.4f}")
    
    return results

# Run demonstration
prefix_results = demonstrate_prefix_tuning()

## 📊 Comparative Analysis

Comparison of LoRA and Prefix-tuning based on the paper's findings.

In [None]:
# Comparative analysis visualization
def compare_peft_methods():
    """Compare LoRA vs Prefix-tuning characteristics"""
    
    # Paper results reproduction (Table IX)
    paper_comparison = {
        'Method': ['Prefix-tuning', 'LoRA (r=8)', 'LoRA (r=16)'],
        'Trainable_Params_M': [1.2, 4.2, 8.4],  # Million parameters
        'Storage_MB': [2.4, 8, 16],
        'RNP_F1': [None, 69.34, 70.49],  # Prefix-tuning couldn't do classification
        'RCG_BLEU': [5.16, 5.64, 5.70],
        'CR_BLEU': [76.71, 81.59, 82.27]
    }
    
    # Create visualization
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('PEFT Methods Comparison: LoRA vs Prefix-tuning', fontsize=16, fontweight='bold')
    
    # 1. Parameter Efficiency
    ax1 = axes[0, 0]
    methods = paper_comparison['Method']
    params = paper_comparison['Trainable_Params_M']
    
    bars = ax1.bar(methods, params, color=['lightcoral', 'lightblue', 'lightgreen'])
    ax1.set_title('Trainable Parameters (Millions)')
    ax1.set_ylabel('Parameters (M)')
    ax1.tick_params(axis='x', rotation=45)
    
    # Add value labels
    for bar, value in zip(bars, params):
        ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.1,
                f'{value}M', ha='center', va='bottom')
    
    # 2. Storage Requirements
    ax2 = axes[0, 1]
    storage = paper_comparison['Storage_MB']
    
    bars = ax2.bar(methods, storage, color=['lightcoral', 'lightblue', 'lightgreen'])
    ax2.set_title('Storage Requirements (MB)')
    ax2.set_ylabel('Storage (MB)')
    ax2.tick_params(axis='x', rotation=45)
    
    for bar, value in zip(bars, storage):
        ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.3,
                f'{value}MB', ha='center', va='bottom')
    
    # 3. Performance - Review Comment Generation
    ax3 = axes[1, 0]
    rcg_scores = paper_comparison['RCG_BLEU']
    
    bars = ax3.bar(methods, rcg_scores, color=['lightcoral', 'lightblue', 'lightgreen'])
    ax3.set_title('Review Comment Generation (BLEU-4)')
    ax3.set_ylabel('BLEU-4 Score')
    ax3.tick_params(axis='x', rotation=45)
    
    for bar, value in zip(bars, rcg_scores):
        ax3.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.05,
                f'{value:.2f}', ha='center', va='bottom')
    
    # 4. Performance - Code Refinement
    ax4 = axes[1, 1]
    cr_scores = paper_comparison['CR_BLEU']
    
    bars = ax4.bar(methods, cr_scores, color=['lightcoral', 'lightblue', 'lightgreen'])
    ax4.set_title('Code Refinement (BLEU-4)')
    ax4.set_ylabel('BLEU-4 Score')
    ax4.tick_params(axis='x', rotation=45)
    
    for bar, value in zip(bars, cr_scores):
        ax4.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
                f'{value:.1f}', ha='center', va='bottom')
    
    plt.tight_layout()
    plt.show()
    
    # Print key insights
    print("\n🔍 Key Insights from Paper (RQ4):")
    print("=" * 50)
    print("1. LoRA consistently outperforms Prefix-tuning across all tasks")
    print("2. Higher LoRA rank (r=16) achieves better performance than r=8")
    print("3. Prefix-tuning struggles with classification tasks (RNP)")
    print("4. LoRA provides better approximation to full-parameter tuning")
    print("5. Trade-off: Higher rank = more parameters but better performance")
    
    return paper_comparison

# Run comparison
comparison_data = compare_peft_methods()

## 🧪 Hands-on Experimentation

Experiments with mock data to understand how each method behaves.

In [None]:
# Hands-on experimentation
def experiment_with_peft():
    """Experiment to understand PEFT behavior"""
    
    print("🧪 PEFT Experimentation")
    print("=" * 40)
    
    # Create a simple "pre-trained" layer
    original_layer = nn.Linear(256, 256)
    
    # Create mock task: transform input in a specific way
    # Target: multiply input by 2 (simple adaptation task)
    target_transform = lambda x: x * 2
    
    # Test data
    test_input = torch.randn(10, 256)
    target_output = target_transform(test_input)
    
    # Method 1: LoRA adaptation
    print("\n🎯 Testing LoRA adaptation:")
    lora_layer = LoRALayer(original_layer, rank=8, alpha=16)
    
    # Simple "training" - just to show adaptation (optimize only the LoRA factors)
    optimizer = torch.optim.AdamW(
        [p for p in lora_layer.parameters() if p.requires_grad], lr=0.01)
    
    initial_loss = None
    for epoch in range(100):
        optimizer.zero_grad()
        output = lora_layer(test_input)
        loss = nn.MSELoss()(output, target_output)
        
        if epoch == 0:
            initial_loss = loss.item()
        
        loss.backward()
        optimizer.step()
        
        if epoch % 20 == 0:
            print(f"  Epoch {epoch}: Loss = {loss.item():.4f}")
    
    final_loss = loss.item()
    improvement = (initial_loss - final_loss) / initial_loss * 100
    
    print(f"  LoRA adaptation: {improvement:.1f}% loss reduction")
    print(f"  Trainable params: {sum(p.numel() for p in lora_layer.parameters() if p.requires_grad):,}")
    
    # Method 2: Full fine-tuning comparison
    print("\n🎯 Testing full fine-tuning (baseline):")
    full_layer = nn.Linear(256, 256)
    full_layer.load_state_dict(original_layer.state_dict())  # Start from same weights
    
    optimizer_full = torch.optim.AdamW(full_layer.parameters(), lr=0.01)
    
    for epoch in range(100):
        optimizer_full.zero_grad()
        output = full_layer(test_input)
        loss = nn.MSELoss()(output, target_output)
        
        loss.backward()
        optimizer_full.step()
        
        if epoch % 20 == 0:
            print(f"  Epoch {epoch}: Loss = {loss.item():.4f}")
    
    print(f"  Full fine-tuning final loss: {loss.item():.4f}")
    print(f"  Trainable params: {sum(p.numel() for p in full_layer.parameters()):,}")
    
    # Analysis
    lora_params = sum(p.numel() for p in lora_layer.parameters() if p.requires_grad)
    full_params = sum(p.numel() for p in full_layer.parameters())
    
    print(f"\n📊 Efficiency Analysis:")
    print(f"  Parameter reduction: {(1 - lora_params/full_params)*100:.1f}%")
    print(f"  LoRA final loss: {final_loss:.4f}")
    print(f"  Full fine-tuning final loss: {loss.item():.4f}")
    print(f"  Performance gap: {abs(final_loss - loss.item()):.4f}")

# Run experiment
experiment_with_peft()

## 🎯 Key Takeaways

### From Paper Analysis (Section V.D):

1. **LoRA Superior Performance**: LoRA outperforms prefix-tuning across all code review tasks
   - Better approximation to full-parameter tuning
   - More flexible adaptation mechanism

2. **Rank Selection Impact**: Higher LoRA rank improves performance
   - r=16 > r=8 consistently
   - Trade-off between efficiency and performance

3. **Task-Specific Behavior**: 
   - Prefix-tuning struggles with classification (RNP)
   - Both methods effective for generation tasks (RCG, CR)

4. **Efficiency Gains**: Both methods achieve <1% trainable parameters
   - 99%+ parameter reduction vs full fine-tuning
   - Significant storage savings (13GB → <20MB)

### Implementation Insights:

- **LoRA**: Low-rank approximation captures essential adaptations
- **Prefix-tuning**: Zero-init gating provides stable training
- **Both**: Enable efficient deployment of specialized models

### Practical Applications:

- **Multi-task deployment**: Different PEFT adapters for different tasks
- **Resource constraints**: Minimal memory and storage requirements
- **Rapid adaptation**: Quick fine-tuning for new domains/tasks
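
The multi-task deployment point can be sketched as one frozen base layer shared by several small adapters, selected per request. The task names, sizes, and helper functions here are hypothetical, and the adapter values are random stand-ins for trained factors.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden, rank = 64, 4

# One frozen base layer shared by every task.
base = nn.Linear(hidden, hidden)
for p in base.parameters():
    p.requires_grad = False

def make_adapter() -> dict:
    # Non-zero random values stand in for factors that were already trained.
    return {"A": torch.randn(rank, hidden) * 0.1,   # r x k
            "B": torch.randn(hidden, rank) * 0.1}   # d x r

# Hypothetical task names, one small adapter per task.
adapters = {"review_comments": make_adapter(), "code_refinement": make_adapter()}

def forward(x: torch.Tensor, task: str) -> torch.Tensor:
    a = adapters[task]
    # Base output plus the low-rank update for the selected task.
    return base(x) + (x @ a["A"].T) @ a["B"].T

x = torch.randn(2, hidden)
out_a = forward(x, "review_comments")
out_b = forward(x, "code_refinement")

base_params = sum(p.numel() for p in base.parameters())
adapter_params = sum(t.numel() for t in adapters["review_comments"].values())
print(f"base: {base_params:,} params, each adapter: {adapter_params:,} params")
print(f"outputs differ across tasks: {not torch.allclose(out_a, out_b)}")
```

Only the adapter dictionary grows with the number of tasks; the base weights are stored and loaded once.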