# Focused Learning: Parameter-Efficient Fine-Tuning (PEFT) Methods
## Deep Dive into LoRA and Zero-init Attention Prefix-tuning

### Learning Objectives:
- Understand the mathematical foundations of LoRA and Prefix-tuning
- Implement both methods from scratch
- Compare their effectiveness for code review tasks
- Visualize how these methods modify model behavior

### Paper References:
- **Section III.C**: Zero-init Attention Prefix-tuning (Pages 3-4)
- **Section III.D**: Low-Rank Adaptation (Page 4)
- **Figure 4**: Details of prefix-tuning on LLaMA
- **Figure 5**: Core component of LoRA

## 1. Understanding PEFT: Why Do We Need It?

Large Language Models like LLaMA-7B have billions of parameters. Fine-tuning all parameters is:
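- **Computationally expensive**: Gradients and optimizer states for every parameter require massive GPU memory
- **Storage intensive**: Each fine-tuned copy of a 7B model needs ~13 GB of storage in 16-bit precision
- **Inefficient for multi-task use**: Every task needs its own full copy of the model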

PEFT methods solve this by freezing most parameters and only training a small subset.
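
As a rough check on the storage figure above, the following sketch multiplies a nominal parameter count by the bytes per parameter. The round 7-billion count and the helper name `checkpoint_size_gib` are illustrative assumptions, not values taken from the paper.

```python
# Back-of-the-envelope checkpoint size (assumption: a nominal 7e9 parameters).
def checkpoint_size_gib(num_params: float, bytes_per_param: int) -> float:
    """Return the raw weight storage in GiB (2**30 bytes)."""
    return num_params * bytes_per_param / 2**30

nominal_params = 7e9  # hypothetical round number for a "7B" model

fp16_gib = checkpoint_size_gib(nominal_params, 2)  # 16-bit weights
fp32_gib = checkpoint_size_gib(nominal_params, 4)  # 32-bit weights

print(f"fp16: {fp16_gib:.1f} GiB, fp32: {fp32_gib:.1f} GiB")
```

Every additional fully fine-tuned task adds another copy of this size, which is the cost PEFT methods avoid.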

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Optional
import seaborn as sns

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Helper function to count parameters
def count_parameters(model: nn.Module) -> Tuple[int, int]:
    """Count total and trainable parameters"""
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params, trainable_params

def print_parameter_stats(model: nn.Module, model_name: str):
    """Print parameter statistics"""
    total, trainable = count_parameters(model)
    percentage = (trainable / total) * 100 if total > 0 else 0
    print(f"\n{model_name}:")
    print(f"Total parameters: {total:,}")
    print(f"Trainable parameters: {trainable:,}")
    print(f"Trainable percentage: {percentage:.2f}%")

## 2. Low-Rank Adaptation (LoRA) - Mathematical Foundation

### Core Idea:
Instead of updating the full weight matrix $W_0 \in \mathbb{R}^{d \times k}$, LoRA approximates the weight update as:

$$W_0 + \Delta W = W_0 + W_{down} \cdot W_{up}$$

Where:
- $W_{down} \in \mathbb{R}^{d \times r}$
- $W_{up} \in \mathbb{R}^{r \times k}$
- $r \ll \min(d, k)$ (rank)

This reduces parameters from $d \times k$ to $r \times (d + k)$.
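
To make the parameter count concrete, here is a minimal NumPy sketch using the symbols above. The sizes d = k = 768 and r = 16 are assumed for illustration, and the random matrices merely stand in for trained factors.

```python
import numpy as np

d, k, r = 768, 768, 16  # assumed sizes for illustration

rng = np.random.default_rng(0)
W_down = rng.standard_normal((d, r))  # d x r factor
W_up = rng.standard_normal((r, k))    # r x k factor
delta_W = W_down @ W_up               # d x k update whose rank is at most r

full_params = d * k        # parameters in a dense update
lora_params = r * (d + k)  # parameters in the two low-rank factors

print(f"Full update: {full_params:,} parameters")
print(f"LoRA update: {lora_params:,} parameters "
      f"({100 * lora_params / full_params:.2f}% of full)")
print(f"rank(delta_W) = {np.linalg.matrix_rank(delta_W)}")
```

The update has the same shape as $W_0$, but its rank never exceeds $r$, which is the constraint LoRA trades for the smaller parameter count.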

In [None]:
class LoRALayer(nn.Module):
    """Implementation of LoRA layer based on the paper"""
    
    def __init__(self, in_features: int, out_features: int, rank: int = 16, alpha: int = 16):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank
        
        # Frozen pretrained weight - not updated during training
        self.pretrained_weight = nn.Parameter(torch.randn(out_features, in_features))
        self.pretrained_weight.requires_grad = False
        
        # LoRA decomposition matrices - these are trained
        self.lora_down = nn.Parameter(torch.zeros(rank, in_features))
        self.lora_up = nn.Parameter(torch.zeros(out_features, rank))
        
        # Initialize LoRA weights
        nn.init.kaiming_uniform_(self.lora_down, a=np.sqrt(5))
        nn.init.zeros_(self.lora_up)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Original computation
        h = x @ self.pretrained_weight.t()
        
        # LoRA computation: h + scaling * x @ W_down^T @ W_up^T
        lora_output = x @ self.lora_down.t() @ self.lora_up.t()
        
        return h + self.scaling * lora_output

# Demonstrate LoRA layer
input_dim, output_dim = 768, 768  # Typical transformer dimensions
batch_size, seq_len = 4, 128

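# Create layers with different ranks
lora_r8 = LoRALayer(input_dim, output_dim, rank=8)
lora_r16 = LoRALayer(input_dim, output_dim, rank=16)
full_linear = nn.Linear(input_dim, output_dim)

# Forward pass on a dummy batch; lora_up starts at zero, so the LoRA branch adds nothing yet
x = torch.randn(batch_size, seq_len, input_dim)
print("LoRA output shape:", tuple(lora_r8(x).shape))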

print("Parameter Comparison:")
print_parameter_stats(full_linear, "Full Linear Layer")
print_parameter_stats(lora_r8, "LoRA (r=8)")
print_parameter_stats(lora_r16, "LoRA (r=16)")

# Visualize parameter reduction
models = ['Full Fine-tuning', 'LoRA (r=16)', 'LoRA (r=8)']
params = [input_dim * output_dim, 16 * (input_dim + output_dim), 8 * (input_dim + output_dim)]

plt.figure(figsize=(10, 6))
bars = plt.bar(models, params, color=['red', 'green', 'blue'])
plt.ylabel('Number of Trainable Parameters')
plt.title('LoRA Parameter Efficiency')
for bar, param in zip(bars, params):
    height = bar.get_height()
    plt.text(bar.get_x() + bar.get_width()/2., height,
             f'{param:,}\n({param/params[0]*100:.1f}%)', 
             ha='center', va='bottom')
plt.tight_layout()
plt.show()

## 3. Zero-init Attention Prefix-tuning - Mathematical Foundation

### Core Idea:
Prefix-tuning adds learnable "soft prompts" (prefix tokens) to the input, which influence the attention mechanism.

The attention computation with prefix becomes:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q[K_p; K]^T}{\sqrt{d}}\right)[V_p; V]$$

Where $K_p, V_p$ are learnable prefix keys and values.

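The "zero-init" variant adds a learnable gating factor $g_l$ per layer, initialized to zero, so the prefix has no effect at the start of training and its influence grows only as $g_l$ is learned.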
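
One way to see why the gate starts at zero: if the prefix attention weights are multiplied by a gate that equals zero initially, the layer behaves exactly like the frozen attention until training moves the gate. The sketch below assumes the gate is applied as $\tanh(g_l)$ to a separately normalized prefix softmax, which is how zero-init attention is commonly implemented. The function name `gated_prefix_attention` and the tensor sizes are illustrative, not taken from the paper.

```python
import torch

def gated_prefix_attention(Q, K, V, K_p, V_p, gate):
    """Single-head attention with a gated prefix (illustrative sketch).

    Q, K, V: (seq_len, d) projections of the original tokens.
    K_p, V_p: (prefix_len, d) learnable prefix keys and values.
    gate: scalar tensor g; the prefix term is scaled by tanh(g).
    """
    d = Q.size(-1)
    token_weights = torch.softmax(Q @ K.T / d**0.5, dim=-1)
    prefix_weights = torch.softmax(Q @ K_p.T / d**0.5, dim=-1)
    # Assumption: softmax is taken separately over prefix and tokens,
    # and only the prefix part is scaled by the gate.
    return token_weights @ V + torch.tanh(gate) * (prefix_weights @ V_p)

torch.manual_seed(0)
seq_len, prefix_len, d = 5, 3, 8
Q, K, V = (torch.randn(seq_len, d) for _ in range(3))
K_p, V_p = torch.randn(prefix_len, d), torch.randn(prefix_len, d)

plain = torch.softmax(Q @ K.T / d**0.5, dim=-1) @ V
at_init = gated_prefix_attention(Q, K, V, K_p, V_p, gate=torch.zeros(1))
after_update = gated_prefix_attention(Q, K, V, K_p, V_p, gate=torch.tensor([0.5]))

print(torch.allclose(at_init, plain))       # gate = 0: prefix has no effect
print(torch.allclose(after_update, plain))  # gate != 0: prefix changes the output
```

Note that scaling the prefix scores before a joint softmax would not achieve this, because a score of zero still receives a nonzero softmax weight.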

In [None]:
class ZeroInitPrefixAttention(nn.Module):
    """Implementation of Zero-init Attention Prefix-tuning based on the paper"""
    
    def __init__(self, embed_dim: int, num_heads: int, prefix_length: int = 10):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.prefix_length = prefix_length
        
        # Learnable prefix tokens
        self.prefix_tokens = nn.Parameter(torch.randn(prefix_length, embed_dim))
        
        # Projection layers for queries, keys, values (frozen in practice)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        
        # Freeze the projection layers
        for proj in [self.q_proj, self.k_proj, self.v_proj, self.out_proj]:
            for p in proj.parameters():
                p.requires_grad = False
        
        # Zero-init gating factor (one per layer in practice)
        self.gating = nn.Parameter(torch.zeros(1))
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, _ = x.shape
        
        # Expand prefix tokens for batch
        prefix = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1, -1)
        
        # Concatenate prefix with input
        x_with_prefix = torch.cat([prefix, x], dim=1)
        
        # Compute Q, K, V
        Q = self.q_proj(x)  # Only compute Q for original tokens
        K = self.k_proj(x_with_prefix)
        V = self.v_proj(x_with_prefix)
        
        # Reshape for multi-head attention
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len + self.prefix_length, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len + self.prefix_length, self.num_heads, self.head_dim).transpose(1, 2)
        
        # Attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.head_dim)
        
        # Split scores for prefix and original tokens
        prefix_scores = scores[..., :self.prefix_length]
        token_scores = scores[..., self.prefix_length:]
        
        # Softmax each part separately, then gate the prefix weights with tanh(g).
        # tanh(0) = 0, so at initialization the prefix contributes nothing and the
        # layer reproduces the frozen attention (sigmoid(0) = 0.5 would not).
        prefix_weights = torch.softmax(prefix_scores, dim=-1) * torch.tanh(self.gating)
        token_weights = torch.softmax(token_scores, dim=-1)
        attn_weights = torch.cat([prefix_weights, token_weights], dim=-1)
        
        # Apply attention to values
        attn_output = torch.matmul(attn_weights, V)
        
        # Reshape and project output
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
        output = self.out_proj(attn_output)
        
        return output

# Demonstrate prefix attention
embed_dim = 768
num_heads = 12
prefix_attention = ZeroInitPrefixAttention(embed_dim, num_heads, prefix_length=10)

print("\nPrefix-tuning Parameter Analysis:")
print_parameter_stats(prefix_attention, "Zero-init Prefix Attention")

# Compare with standard attention
standard_attention = nn.MultiheadAttention(embed_dim, num_heads)
print_parameter_stats(standard_attention, "Standard Multi-head Attention")

## 4. Visualizing How PEFT Methods Work

Let's visualize how LoRA and Prefix-tuning modify the model's computations.

In [None]:
# Visualize LoRA decomposition
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Generate example matrices
d, k, r = 8, 8, 2
W0 = np.random.randn(d, k)
W_down = np.random.randn(d, r) * 0.1
W_up = np.random.randn(r, k) * 0.1
delta_W = W_down @ W_up

# Plot original weight
im1 = axes[0].imshow(W0, cmap='coolwarm', aspect='auto')
axes[0].set_title('Original Weight W₀\n(Frozen)', fontsize=14)
axes[0].set_xlabel('Input dimension')
axes[0].set_ylabel('Output dimension')
plt.colorbar(im1, ax=axes[0])

# Plot LoRA decomposition
im2 = axes[1].imshow(delta_W, cmap='coolwarm', aspect='auto')
axes[1].set_title('LoRA Update ΔW = W_down × W_up\n(Trainable)', fontsize=14)
axes[1].set_xlabel('Input dimension')
axes[1].set_ylabel('Output dimension')
plt.colorbar(im2, ax=axes[1])

# Add annotation showing decomposition
axes[1].text(k/2, d+1, f'Rank r={r}', ha='center', fontsize=12)

# Plot final weight
im3 = axes[2].imshow(W0 + delta_W, cmap='coolwarm', aspect='auto')
axes[2].set_title('Final Weight W₀ + ΔW', fontsize=14)
axes[2].set_xlabel('Input dimension')
axes[2].set_ylabel('Output dimension')
plt.colorbar(im3, ax=axes[2])

plt.tight_layout()
plt.show()

# Visualize prefix-tuning attention pattern
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Row-wise softmax helper (NumPy does not provide one)
def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

# Standard attention pattern
seq_len = 10
attn_standard = softmax(np.random.randn(seq_len, seq_len), axis=-1)
im1 = ax1.imshow(attn_standard, cmap='Blues', aspect='auto')
ax1.set_title('Standard Attention', fontsize=14)
ax1.set_xlabel('Key/Value positions')
ax1.set_ylabel('Query positions')
plt.colorbar(im1, ax=ax1)

# Prefix-tuning attention pattern
prefix_len = 3
total_len = seq_len + prefix_len
attn_prefix = softmax(np.random.randn(seq_len, total_len), axis=-1)
# Scale down the prefix region to mimic the gating effect (illustrative only)
attn_prefix[:, :prefix_len] *= 0.5

im2 = ax2.imshow(attn_prefix, cmap='Blues', aspect='auto')
ax2.set_title('Prefix-tuning Attention', fontsize=14)
ax2.set_xlabel('Key/Value positions')
ax2.set_ylabel('Query positions')
ax2.axvline(x=prefix_len-0.5, color='red', linestyle='--', linewidth=2)
ax2.text((prefix_len - 1) / 2, -0.5, 'Prefix', ha='center', color='red', fontsize=12)
ax2.text(prefix_len + (seq_len - 1) / 2, -0.5, 'Original', ha='center', fontsize=12)
plt.colorbar(im2, ax=ax2)

plt.tight_layout()
plt.show()

## 5. Comparing LoRA vs Prefix-tuning for Code Review

Based on the paper's findings, let's analyze why LoRA outperforms Prefix-tuning for code review tasks.

In [None]:
# Simulate performance comparison based on paper results
comparison_data = {
    'Method': ['Prefix-tuning', 'LoRA (r=8)', 'LoRA (r=16)'],
    'Review Comment Generation': [5.16, 5.64, 5.70],
    'Code Refinement': [76.71, 81.59, 82.27],
    'Trainable Params (M)': [1.2, 4.2, 8.4],
    'Storage (MB)': [2.4, 8.0, 16.0]
}

import pandas as pd
df = pd.DataFrame(comparison_data)

# Create subplots
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Performance comparison
ax1 = axes[0, 0]
x = np.arange(len(df['Method']))
width = 0.35
bars1 = ax1.bar(x - width/2, df['Review Comment Generation'], width, label='Comment Gen', color='skyblue')
bars2 = ax1.bar(x + width/2, df['Code Refinement']/10, width, label='Code Refine (/10)', color='lightcoral')
ax1.set_xlabel('Method')
ax1.set_ylabel('BLEU-4 Score')
ax1.set_title('Task Performance Comparison')
ax1.set_xticks(x)
ax1.set_xticklabels(df['Method'])
ax1.legend()
ax1.grid(axis='y', alpha=0.3)

# Parameter efficiency
ax2 = axes[0, 1]
ax2.bar(df['Method'], df['Trainable Params (M)'], color='green', alpha=0.7)
ax2.set_xlabel('Method')
ax2.set_ylabel('Trainable Parameters (Millions)')
ax2.set_title('Parameter Efficiency')
ax2.grid(axis='y', alpha=0.3)

# Analysis: Why LoRA performs better
ax3 = axes[1, 0]
ax3.axis('off')
analysis_text = """Why LoRA Outperforms Prefix-tuning:

1. Better Approximation:
   - LoRA directly modifies weight matrices
   - Prefix only influences attention patterns

2. Task Alignment:
   - Code review needs precise token generation
   - LoRA better captures fine-grained patterns

3. Flexibility:
   - LoRA rank can be adjusted (r=8, 16)
   - More parameters = better performance

4. Training Stability:
   - LoRA has smoother optimization landscape
   - Prefix-tuning's zero-init can be unstable"""
ax3.text(0.1, 0.9, analysis_text, transform=ax3.transAxes, 
         fontsize=11, verticalalignment='top', fontfamily='monospace')

# Practical considerations
ax4 = axes[1, 1]
ax4.axis('off')
practical_text = """Practical Considerations:

Choose LoRA when:
✓ Need best performance
✓ Have moderate GPU memory
✓ Multiple tasks to fine-tune

Choose Prefix-tuning when:
✓ Extremely limited memory
✓ Need minimal storage
✓ Simple classification tasks

Recommended: LoRA with r=16
- Best performance/efficiency balance
- Still <1% of full parameters"""
ax4.text(0.1, 0.9, practical_text, transform=ax4.transAxes,
         fontsize=11, verticalalignment='top', fontfamily='monospace')

plt.tight_layout()
plt.show()

## 6. Implementing PEFT for Code Review Task

Let's implement a complete example of using LoRA for the review comment generation task.

In [None]:
class CodeReviewLoRAModel(nn.Module):
    """Simplified model demonstrating LoRA for code review"""
    
    def __init__(self, vocab_size: int = 50000, embed_dim: int = 768, 
                 num_layers: int = 4, lora_rank: int = 16):
        super().__init__()
        
        # Token embeddings (frozen in practice)
        self.embeddings = nn.Embedding(vocab_size, embed_dim)
        self.embeddings.weight.requires_grad = False
        
        # Transformer layers with LoRA
        self.layers = nn.ModuleList([
            self._create_layer_with_lora(embed_dim, lora_rank) 
            for _ in range(num_layers)
        ])
        
        # Output projection (with LoRA)
        self.output_proj = LoRALayer(embed_dim, vocab_size, rank=lora_rank)
    
    def _create_layer_with_lora(self, embed_dim: int, lora_rank: int) -> nn.Module:
        """Create a transformer layer with LoRA adapters"""
        return nn.ModuleDict({
            'self_attn_q': LoRALayer(embed_dim, embed_dim, rank=lora_rank),
            'self_attn_k': LoRALayer(embed_dim, embed_dim, rank=lora_rank),
            'self_attn_v': LoRALayer(embed_dim, embed_dim, rank=lora_rank),
            'self_attn_o': LoRALayer(embed_dim, embed_dim, rank=lora_rank),
            'layer_norm1': nn.LayerNorm(embed_dim),
            'layer_norm2': nn.LayerNorm(embed_dim),
            # FFN would also have LoRA in practice
        })
    
    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Simplified forward pass
        x = self.embeddings(input_ids)
        
        # Pass through layers (simplified: single-head attention, no causal mask, no FFN)
        for layer in self.layers:
            q = layer['self_attn_q'](x)
            k = layer['self_attn_k'](x)
            v = layer['self_attn_v'](x)
            attn = torch.softmax(q @ k.transpose(-2, -1) / np.sqrt(q.size(-1)), dim=-1)
            # Residual connection around the attention block, then layer norm
            x = layer['layer_norm1'](x + layer['self_attn_o'](attn @ v))
        
        # Output projection
        logits = self.output_proj(x)
        return logits

# Create model and analyze parameters
model = CodeReviewLoRAModel(vocab_size=30000, embed_dim=768, num_layers=4, lora_rank=16)
print_parameter_stats(model, "Code Review Model with LoRA")

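# Estimate the reduction in trainable parameters versus full fine-tuning
# Approximate full model: embeddings + 4 layers x 4 attention projections + output head
full_model_params = 30000 * 768 + 4 * 4 * 768 * 768 + 768 * 30000
# Trainable count: LoRA matrices plus the small LayerNorm parameters
lora_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
savings = (1 - lora_params / full_model_params) * 100

print(f"\nTrainable-parameter reduction: {savings:.1f}%")
print(f"Storage per task: {lora_params * 4 / 1024 / 1024:.1f} MB (float32)")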

## 7. Key Insights and Takeaways

### From the Paper's Experiments:

1. **LoRA Performance**:
   - Achieves 5.70 BLEU-4 on comment generation (best among all models)
   - Uses only 8.4M parameters vs 220M for baseline models
   - Storage: 16MB vs 850MB for full models

2. **Prefix-tuning Limitations**:
   - Lower performance (5.16 BLEU-4)
   - Incompatible with classification tasks
   - Less stable training with instruction tuning

3. **Rank Selection**:
   - r=16 provides best performance
   - r=8 offers better efficiency with slight performance drop
   - Higher ranks approach full fine-tuning performance

### Practical Implementation Tips:

1. **When to Use LoRA**:
   - Multi-task scenarios (different adapters per task)
   - Limited GPU memory
   - Need to preserve base model

2. **Optimization**:
   - Learning rate: 3e-4 for LoRA (vs 5e-5 for full fine-tuning)
   - Weight decay: 0.01
   - Batch size can be larger due to memory savings

3. **Integration**:
   - LoRA weights can be merged into base model for inference
   - Multiple LoRA adapters can be swapped dynamically
   - Compatible with quantization (8-bit, 4-bit)
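
The merge mentioned in the integration tips follows from linearity: the frozen weight and the scaled low-rank product can be summed into a single matrix. The sketch below checks this numerically. It uses PyTorch's `(out_features, in_features)` weight layout, and the names `lora_down`, `lora_up`, and `scaling` as well as the sizes are assumptions chosen for illustration.

```python
import torch

torch.manual_seed(0)
in_features, out_features, rank, alpha = 32, 24, 4, 8
scaling = alpha / rank

W0 = torch.randn(out_features, in_features)       # frozen base weight
lora_down = torch.randn(rank, in_features) * 0.1   # stands in for a trained factor
lora_up = torch.randn(out_features, rank) * 0.1    # stands in for a trained factor

x = torch.randn(5, in_features)

# Adapter kept separate: two extra small matmuls per forward pass
unmerged = x @ W0.T + scaling * (x @ lora_down.T @ lora_up.T)

# Adapter folded into the base weight: a single matmul at inference time
W_merged = W0 + scaling * (lora_up @ lora_down)
merged = x @ W_merged.T

print(torch.allclose(unmerged, merged, atol=1e-5))

# Swapping adapters: subtract the same term to recover the base weight
W_restored = W_merged - scaling * (lora_up @ lora_down)
print(torch.allclose(W_restored, W0, atol=1e-6))
```

Keeping the adapter unmerged makes swapping cheap, while merging removes the extra matmuls at inference; which to prefer depends on how often the task changes.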

In [None]:
# Final summary visualization
fig, ax = plt.subplots(figsize=(12, 8))

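# Storage and parameter reduction follow the figures quoted above;
# relative performance is a rough illustration
methods = ['Full\nFine-tuning', 'LoRA\n(r=16)', 'LoRA\n(r=8)', 'Prefix\nTuning']
performance = [100, 98, 95, 85]  # Relative performance (%, illustrative)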
efficiency = [0, 99.87, 99.93, 99.98]  # Parameter reduction %
storage = [850, 16, 8, 2.4]  # MB

# Create bubble chart
colors = ['red', 'green', 'blue', 'orange']
for i, (method, perf, eff, stor) in enumerate(zip(methods, performance, efficiency, storage)):
    ax.scatter(eff, perf, s=stor*5, c=colors[i], alpha=0.6, edgecolors='black', linewidth=2)
    ax.annotate(method, (eff, perf), ha='center', va='center', fontsize=10, fontweight='bold')

ax.set_xlabel('Parameter Reduction (%)', fontsize=12)
ax.set_ylabel('Relative Performance (%)', fontsize=12)
ax.set_title('PEFT Methods: Performance vs Efficiency Trade-off\n(Bubble size = Storage in MB)', fontsize=14)
ax.grid(True, alpha=0.3)
ax.set_xlim(-5, 105)
ax.set_ylim(80, 105)

# Add recommendation box
textstr = 'Paper Recommendation:\nLoRA with r=16 for best\nperformance-efficiency balance'
props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
ax.text(0.02, 0.98, textstr, transform=ax.transAxes, fontsize=11,
        verticalalignment='top', bbox=props)

plt.tight_layout()
plt.show()