# 11: LoRA from Scratch

## Learning Objectives

1. Understand why fine-tuning weight changes are often approximately low-rank
2. Implement LoRA (Low-Rank Adaptation) from first principles
3. Compare parameter counts: full fine-tuning vs LoRA
4. Train LoRA adapters on a classification task
5. Demonstrate merging LoRA weights for zero-overhead inference

**Prerequisites:** [fine-tuning](../modern-llms/fine-tuning.md), [efficient-adaptation](../modern-llms/efficient-adaptation.md)

**Framework:** PyTorch

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt
import copy
from tqdm.auto import tqdm

torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Using device: {device}")

## Part 1: The Intuition - Weight Changes Are Low-Rank

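When we fine-tune a model, how much do the weights actually change? The idea behind LoRA is that the update $\Delta W$ is approximately low-rank. The cell below does not fine-tune anything: it builds a rank-8 update by construction and inspects its singular values, to show what a low-rank update looks like.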

In [None]:
def analyze_weight_change_rank(d_in, d_out, simulated_rank=8):
    """
    Illustrate the singular-value spectrum of a low-rank weight change.
    
    Fine-tuning updates W -> W + delta_W. Here delta_W is constructed to have
    rank `simulated_rank`, standing in for the approximately low-rank
    updates that LoRA assumes.
    """
    # Simulate a weight change that's inherently low-rank
    # (In real fine-tuning, this emerges from the optimization process)
    
    # True low-rank structure: delta_W = B @ A where B is d_out x r, A is r x d_in
    true_B = torch.randn(d_out, simulated_rank) * 0.1
    true_A = torch.randn(simulated_rank, d_in) * 0.1
    delta_W = true_B @ true_A  # This is our "weight change"
    
    # SVD to analyze the rank structure
    U, S, Vh = torch.linalg.svd(delta_W, full_matrices=False)
    
    # How much variance is captured by top-k singular values?
    total_variance = (S ** 2).sum().item()
    cumulative_variance = torch.cumsum(S ** 2, dim=0) / total_variance
    
    return S.numpy(), cumulative_variance.numpy()


# Analyze
d_in, d_out = 768, 768  # Typical transformer hidden dim
singular_values, cumvar = analyze_weight_change_rank(d_in, d_out, simulated_rank=8)

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Singular values
axes[0].bar(range(min(50, len(singular_values))), singular_values[:50], color='steelblue')
axes[0].set_xlabel('Singular Value Index')
axes[0].set_ylabel('Singular Value')
axes[0].set_title('Singular Values of Weight Change (delta_W)\n(First 50 shown)')
axes[0].axvline(x=7.5, color='red', linestyle='--', label='True rank=8')  # bars 0-7 are the 8 nonzero values
axes[0].legend()

# Cumulative variance
axes[1].plot(range(1, min(50, len(cumvar))+1), cumvar[:50], 'b-', linewidth=2)
axes[1].axhline(y=0.99, color='gray', linestyle='--', alpha=0.7, label='99% variance')
axes[1].axvline(x=8, color='red', linestyle='--', label='Rank 8')
axes[1].set_xlabel('Number of Components')
axes[1].set_ylabel('Cumulative Variance Explained')
axes[1].set_title('Variance Captured by Top-k Singular Values')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('lora_rank_intuition.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\nKey insight: with true rank 8, the first 8 components capture {cumvar[7]*100:.1f}% of variance")
print("This means we can represent delta_W with far fewer parameters!")
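The simulated update above is exactly rank 8, whereas a real fine-tuning update is at best approximately low-rank. As a rough sketch of how "approximately" could be quantified, the cell below counts how many singular values are needed to hold a given share of the squared spectrum. The helper name `effective_rank`, the 99.9% threshold, and the noise scale are illustrative assumptions, not part of the original notebook.

```python
import torch

def effective_rank(delta_w: torch.Tensor, energy: float = 0.999) -> int:
    """Smallest k whose top-k singular values hold `energy` of the squared spectrum."""
    s = torch.linalg.svdvals(delta_w)                    # singular values, descending
    cum = torch.cumsum(s ** 2, dim=0) / (s ** 2).sum()   # cumulative energy share
    return int((cum >= energy).nonzero()[0].item()) + 1

torch.manual_seed(0)
low_rank = torch.randn(64, 8) @ torch.randn(8, 64)   # exactly rank 8
noisy = low_rank + 0.01 * torch.randn(64, 64)         # no longer exactly rank 8

k_clean = effective_rank(low_rank)
k_noisy = effective_rank(noisy)
print(f"effective rank (clean): {k_clean}, effective rank (with noise): {k_noisy}")
```

The same helper could be pointed at `W_finetuned - W_pretrained` for a real model, which is essentially what Exercise 4 asks for.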

## Part 2: Implementing LoRA from Scratch

LoRA replaces the weight update $\Delta W$ with a low-rank factorization (the implementation below additionally scales $BA$ by $\alpha / r$):

$$W' = W + \Delta W = W + BA$$

Where:
- $W \in \mathbb{R}^{d_{out} \times d_{in}}$ is the frozen pretrained weight
- $B \in \mathbb{R}^{d_{out} \times r}$ is trainable
- $A \in \mathbb{R}^{r \times d_{in}}$ is trainable
- $r \ll \min(d_{out}, d_{in})$ is the rank
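Before wrapping this in a module, it may help to check numerically that applying the merged weight $W + \frac{\alpha}{r} BA$ gives the same output as running a frozen path and a low-rank path separately. The following is a minimal sketch with small, arbitrary dimensions (all sizes and the `alpha` value are illustrative assumptions); it uses float64 so rounding does not blur the comparison.

```python
import torch

torch.manual_seed(0)
d_in, d_out, r, alpha = 16, 12, 4, 8.0
scaling = alpha / r

W = torch.randn(d_out, d_in, dtype=torch.float64)   # frozen weight
A = torch.randn(r, d_in, dtype=torch.float64)       # trainable, r x d_in
B = torch.randn(d_out, r, dtype=torch.float64)      # trainable, d_out x r
x = torch.randn(5, d_in, dtype=torch.float64)       # batch of 5 inputs

merged_path = x @ (W + scaling * (B @ A)).T          # one matmul with the merged weight
two_path = x @ W.T + scaling * ((x @ A.T) @ B.T)     # frozen path + low-rank path

max_diff = (merged_path - two_path).abs().max().item()
print(f"max difference between the two forms: {max_diff:.2e}")
```

The two-path form never materializes the full $d_{out} \times d_{in}$ update, which is why it is the natural choice during training.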

In [None]:
class LoRALinear(nn.Module):
    """
    Linear layer with LoRA adaptation.
    
    The original weight W is frozen. We learn low-rank matrices A and B
    such that the effective weight becomes W + (alpha/r) * B @ A.
    """
    
    def __init__(self, original_linear: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        
        self.original = original_linear
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank
        
        d_out, d_in = original_linear.weight.shape
        
        # Freeze original weights
        self.original.weight.requires_grad = False
        if self.original.bias is not None:
            self.original.bias.requires_grad = False
        
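        # LoRA matrices, created on the same device and dtype as the wrapped layer
        # so the adapter also works when the model has already been moved to a GPU.
        # A is initialized with small random values
        # B is initialized with zeros so that delta_W starts at zero
        factory = {'device': original_linear.weight.device, 'dtype': original_linear.weight.dtype}
        self.lora_A = nn.Parameter(torch.randn(rank, d_in, **factory) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(d_out, rank, **factory))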
        
    def forward(self, x):
        # Original output (frozen path)
        base_output = self.original(x)
        
        # LoRA output (trainable path)
        # x @ A.T @ B.T = x @ (B @ A).T
        lora_output = F.linear(F.linear(x, self.lora_A), self.lora_B)
        
        return base_output + self.scaling * lora_output
    
    def merge_weights(self):
        """Return the merged weight W + scaling * B @ A (the layer itself is not modified)."""
        with torch.no_grad():
            # W' = W + scaling * B @ A
            merged = self.original.weight + self.scaling * (self.lora_B @ self.lora_A)
        return merged
    
    @property
    def num_trainable_params(self):
        return self.lora_A.numel() + self.lora_B.numel()
    
    @property
    def num_total_params(self):
        # Count the wrapped layer's bias too, if it has one
        return sum(p.numel() for p in self.original.parameters()) + self.num_trainable_params


# Demonstrate parameter savings
d_in, d_out = 768, 768
rank = 8

original = nn.Linear(d_in, d_out, bias=False)
lora = LoRALinear(original, rank=rank)

full_params = d_in * d_out
lora_params = lora.num_trainable_params

print(f"Full fine-tuning parameters: {full_params:,}")
print(f"LoRA parameters (rank={rank}): {lora_params:,}")
print(f"Parameter reduction: {100 * (1 - lora_params / full_params):.2f}%")

In [None]:
# Verify that output starts the same (B initialized to zeros)
x = torch.randn(2, 10, d_in)  # [batch, seq, features]

with torch.no_grad():
    original_out = original(x)
    lora_out = lora(x)

print("Initial outputs match (B=0 means delta_W=0):")
print(f"  Max difference: {(original_out - lora_out).abs().max().item():.2e}")
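A related consequence of the zero initialization: at the very first step, only `B` receives a gradient. The gradient with respect to `A` passes through `B`, which is zero, so `A` starts moving only after `B` has been updated. The sketch below checks this with raw tensors rather than the `LoRALinear` class; the dimensions and the squared-error loss are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)
d_in, d_out, r = 6, 5, 2

A = (torch.randn(r, d_in) * 0.01).requires_grad_()   # small random init, as in LoRALinear
B = torch.zeros(d_out, r, requires_grad=True)        # zero init, as in LoRALinear

x = torch.randn(4, d_in)
target = torch.randn(4, d_out)

out = (x @ A.T) @ B.T                  # low-rank path only; the frozen path has no trainable part
loss = ((out - target) ** 2).mean()
loss.backward()

grad_A_norm = A.grad.norm().item()
grad_B_norm = B.grad.norm().item()
print(f"grad norm for A: {grad_A_norm:.4e}")
print(f"grad norm for B: {grad_B_norm:.4e}")
```

If both matrices were initialized to zero, neither would ever receive a gradient, which is one reason `A` is given a random start.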

## Part 3: Building a LoRA-Enabled Transformer

Let's create a small transformer and add LoRA to its attention projections.

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention without LoRA."""
    
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        B, T, C = x.shape
        
        # Project to Q, K, V
        Q = self.W_q(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        
        # Attention
        scores = (Q @ K.transpose(-2, -1)) / (self.d_head ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)
        
        # Output
        out = (weights @ V).transpose(1, 2).contiguous().view(B, T, C)
        return self.W_o(out)


class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff=None, dropout=0.1):
        super().__init__()
        d_ff = d_ff or 4 * d_model
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        return self.linear2(self.dropout(F.gelu(self.linear1(x))))


class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ffn = FeedForward(d_model, dropout=dropout)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        x = x + self.dropout(self.attn(self.ln1(x), mask))
        x = x + self.dropout(self.ffn(self.ln2(x)))
        return x


class SmallTransformer(nn.Module):
    """Small transformer for classification."""
    
    def __init__(self, vocab_size, d_model, n_heads, n_layers, num_classes, max_len=128, dropout=0.1):
        super().__init__()
        
        self.d_model = d_model
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.dropout = nn.Dropout(dropout)
        
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, dropout)
            for _ in range(n_layers)
        ])
        
        self.ln_f = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, num_classes)
        
    def forward(self, input_ids, attention_mask=None):
        B, T = input_ids.shape
        
        # Embeddings
        pos = torch.arange(T, device=input_ids.device)
        x = self.dropout(self.tok_emb(input_ids) + self.pos_emb(pos))
        
        # Transformer blocks
        for block in self.blocks:
            x = block(x)
        
        x = self.ln_f(x)
        
        # Pool and classify (use last token)
        if attention_mask is not None:
            seq_lengths = attention_mask.sum(dim=1) - 1
            pooled = x[torch.arange(B, device=x.device), seq_lengths]
        else:
            pooled = x[:, -1]
        
        return self.classifier(pooled)


# Create model
vocab_size = 1000
d_model = 256
n_heads = 4
n_layers = 4
num_classes = 2

model = SmallTransformer(vocab_size, d_model, n_heads, n_layers, num_classes).to(device)
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
def add_lora_to_model(model, rank=8, alpha=16.0, target_modules=('W_q', 'W_v')):
    """
    Add LoRA adapters to specified modules in a model.
    
    Args:
        model: The model to modify
        rank: LoRA rank
        alpha: LoRA scaling factor
        target_modules: List of module names to add LoRA to
    
    Returns:
        Modified model with LoRA layers
    """
    lora_layers = []
    
    # Snapshot the module list first, since submodules are replaced while iterating
    for name, module in list(model.named_modules()):
        if isinstance(module, MultiHeadAttention):
            for target in target_modules:
                if hasattr(module, target):
                    original_linear = getattr(module, target)
                    lora_linear = LoRALinear(original_linear, rank=rank, alpha=alpha)
                    setattr(module, target, lora_linear)
                    lora_layers.append(lora_linear)
    
    return model, lora_layers


def count_parameters(model, only_trainable=False):
    """Count model parameters."""
    if only_trainable:
        return sum(p.numel() for p in model.parameters() if p.requires_grad)
    return sum(p.numel() for p in model.parameters())


def freeze_non_lora_params(model):
    """Freeze all parameters except LoRA."""
    for name, param in model.named_parameters():
        if 'lora_' not in name:
            param.requires_grad = False


# Create a fresh model and add LoRA
lora_model = SmallTransformer(vocab_size, d_model, n_heads, n_layers, num_classes).to(device)

print("Before LoRA:")
print(f"  Total params: {count_parameters(lora_model):,}")
print(f"  Trainable params: {count_parameters(lora_model, only_trainable=True):,}")

# Add LoRA to Q and V projections
lora_model, lora_layers = add_lora_to_model(lora_model, rank=8, alpha=16.0, target_modules=['W_q', 'W_v'])
freeze_non_lora_params(lora_model)

print("\nAfter LoRA (rank=8, Q and V only):")
print(f"  Total params: {count_parameters(lora_model):,}")
print(f"  Trainable params: {count_parameters(lora_model, only_trainable=True):,}")
print(f"  LoRA layers added: {len(lora_layers)}")

# Compare to full fine-tuning
full_model = SmallTransformer(vocab_size, d_model, n_heads, n_layers, num_classes).to(device)
print(f"\nFull fine-tuning would train: {count_parameters(full_model):,} params")
print(f"LoRA trains only: {count_parameters(lora_model, only_trainable=True):,} params")
print(f"Reduction: {100 * (1 - count_parameters(lora_model, only_trainable=True) / count_parameters(full_model)):.1f}%")
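The trainable count printed above can be cross-checked by hand: each adapted projection adds an `A` of shape `rank x d_in` and a `B` of shape `d_out x rank`. Here is a small sketch of that arithmetic in plain Python; the function name `lora_param_count` is a hypothetical helper introduced only for this check.

```python
def lora_param_count(n_layers: int, n_targets: int, d_in: int, d_out: int, rank: int) -> int:
    """Trainable LoRA parameters: each adapted layer adds A (rank x d_in) and B (d_out x rank)."""
    return n_layers * n_targets * rank * (d_in + d_out)

# Configuration used above: 4 blocks, W_q and W_v adapted, d_model = 256, rank = 8
adapter_params = lora_param_count(n_layers=4, n_targets=2, d_in=256, d_out=256, rank=8)
dense_qv_params = 4 * 2 * 256 * 256   # the same Q and V matrices, trained densely

print(f"LoRA adapter parameters: {adapter_params:,}")
print(f"Dense Q+V parameters:    {dense_qv_params:,}")
print(f"Adapter / dense ratio:   {adapter_params / dense_qv_params:.4f}")
```

The count grows linearly in the rank, so doubling `rank` doubles the adapter size while the frozen model stays the same.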

## Part 4: Training LoRA vs Full Fine-Tuning

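Let's compare LoRA and full fine-tuning on a simple classification task. Note that no model here is pretrained: both runs start from random initialization, so "full fine-tuning" means training every parameter from scratch, while the LoRA run trains only the adapters on top of frozen random weights (the classifier head included). Treat the comparison as a mechanical demonstration rather than a benchmark.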

In [None]:
# Create a synthetic classification dataset
def create_synthetic_data(num_samples, vocab_size, seq_len, num_classes=2):
    """
    Create synthetic data where class is determined by token patterns.
    Class 0: sequences with more low-id tokens (0-vocab_size/2)
    Class 1: sequences with more high-id tokens (vocab_size/2-vocab_size)
    """
    data = []
    mid = vocab_size // 2
    
    for _ in range(num_samples):
        # Decide class
        label = np.random.randint(0, num_classes)
        
        if label == 0:
            # More low tokens
            tokens = np.random.randint(0, mid, size=seq_len)
            # Sprinkle some high tokens
            noise_idx = np.random.choice(seq_len, size=seq_len//4, replace=False)
            tokens[noise_idx] = np.random.randint(mid, vocab_size, size=len(noise_idx))
        else:
            # More high tokens
            tokens = np.random.randint(mid, vocab_size, size=seq_len)
            # Sprinkle some low tokens
            noise_idx = np.random.choice(seq_len, size=seq_len//4, replace=False)
            tokens[noise_idx] = np.random.randint(0, mid, size=len(noise_idx))
        
        data.append((tokens.tolist(), label))
    
    return data


class SyntheticDataset(Dataset):
    def __init__(self, data):
        self.data = data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        tokens, label = self.data[idx]
        return {
            'input_ids': torch.tensor(tokens, dtype=torch.long),
            'label': torch.tensor(label, dtype=torch.long)
        }


# Create datasets
seq_len = 32
train_data = create_synthetic_data(1000, vocab_size, seq_len)
test_data = create_synthetic_data(200, vocab_size, seq_len)

train_dataset = SyntheticDataset(train_data)
test_dataset = SyntheticDataset(test_data)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

In [None]:
def train_model(model, train_loader, test_loader, epochs=10, lr=1e-4, description=""):
    """Train a model and track metrics."""
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=lr,
        weight_decay=0.01
    )
    
    metrics = {'train_loss': [], 'train_acc': [], 'test_acc': []}
    
    for epoch in range(epochs):
        # Training
        model.train()
        total_loss = 0
        correct = 0
        total = 0
        
        for batch in train_loader:
            input_ids = batch['input_ids'].to(device)
            labels = batch['label'].to(device)
            
            optimizer.zero_grad()
            logits = model(input_ids)
            loss = F.cross_entropy(logits, labels)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            preds = logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
        
        train_loss = total_loss / len(train_loader)
        train_acc = correct / total
        
        # Evaluation
        model.train(False)
        correct = 0
        total = 0
        
        with torch.no_grad():
            for batch in test_loader:
                input_ids = batch['input_ids'].to(device)
                labels = batch['label'].to(device)
                
                logits = model(input_ids)
                preds = logits.argmax(dim=-1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        
        test_acc = correct / total
        
        metrics['train_loss'].append(train_loss)
        metrics['train_acc'].append(train_acc)
        metrics['test_acc'].append(test_acc)
        
        if (epoch + 1) % 2 == 0:
            print(f"{description} Epoch {epoch+1}: Loss={train_loss:.4f}, Train Acc={train_acc:.3f}, Test Acc={test_acc:.3f}")
    
    return metrics
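Because `train_model` hands only the parameters with `requires_grad=True` to AdamW, the optimizer state shrinks along with the trainable set. A back-of-the-envelope sketch follows, assuming float32 state and ignoring AdamW's small step counters. The full-model count below is hand-computed from the `SmallTransformer` configuration used in this notebook and should be read as an assumption; `adamw_state_bytes` is a hypothetical helper.

```python
def adamw_state_bytes(n_trainable: int, bytes_per_value: int = 4) -> int:
    """AdamW keeps two running averages (first and second moment) per trainable parameter."""
    return 2 * n_trainable * bytes_per_value

n_full = 3_444_738   # assumed: all parameters of the SmallTransformer configured above
n_lora = 32_768      # 4 layers x 2 projections x rank 8 x (256 + 256)

full_state = adamw_state_bytes(n_full)
lora_state = adamw_state_bytes(n_lora)

print(f"AdamW state, full fine-tuning: {full_state / 1e6:.2f} MB")
print(f"AdamW state, LoRA (r=8):       {lora_state / 1e6:.2f} MB")
```

For a model this small the saving is negligible in absolute terms; the same ratio applied to a multi-billion-parameter model is where it starts to matter.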

In [None]:
# Train with full fine-tuning
print("=" * 60)
print("FULL FINE-TUNING")
print("=" * 60)

full_model = SmallTransformer(vocab_size, d_model, n_heads, n_layers, num_classes).to(device)
full_trainable = count_parameters(full_model, only_trainable=True)
print(f"Trainable parameters: {full_trainable:,}")

full_metrics = train_model(full_model, train_loader, test_loader, epochs=10, lr=1e-4, description="[Full]")

In [None]:
# Train with LoRA
print("\n" + "=" * 60)
print("LORA FINE-TUNING (rank=8)")
print("=" * 60)

lora_model = SmallTransformer(vocab_size, d_model, n_heads, n_layers, num_classes).to(device)
lora_model, _ = add_lora_to_model(lora_model, rank=8, alpha=16.0, target_modules=['W_q', 'W_v'])
freeze_non_lora_params(lora_model)

lora_trainable = count_parameters(lora_model, only_trainable=True)
print(f"Trainable parameters: {lora_trainable:,} ({100*lora_trainable/full_trainable:.1f}% of full)")

lora_metrics = train_model(lora_model, train_loader, test_loader, epochs=10, lr=1e-3, description="[LoRA]")

In [None]:
# Train with different LoRA ranks
print("\n" + "=" * 60)
print("LORA RANK COMPARISON")
print("=" * 60)

rank_results = {}

for rank in [2, 4, 8, 16, 32]:
    model = SmallTransformer(vocab_size, d_model, n_heads, n_layers, num_classes).to(device)
    model, _ = add_lora_to_model(model, rank=rank, alpha=2*rank, target_modules=['W_q', 'W_v'])
    freeze_non_lora_params(model)
    
    trainable = count_parameters(model, only_trainable=True)
    metrics = train_model(model, train_loader, test_loader, epochs=10, lr=1e-3, description=f"[r={rank}]")
    
    rank_results[rank] = {
        'params': trainable,
        'final_test_acc': metrics['test_acc'][-1],
        'metrics': metrics
    }
    print(f"Rank {rank}: {trainable:,} params, Final Test Acc: {metrics['test_acc'][-1]:.3f}\n")

In [None]:
# Visualize results
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

epochs = range(1, 11)

# Test accuracy over training
axes[0].plot(epochs, full_metrics['test_acc'], 'k-', label='Full fine-tune', linewidth=2)
axes[0].plot(epochs, lora_metrics['test_acc'], 'b-', label='LoRA (r=8)', linewidth=2)
for rank, data in rank_results.items():
    if rank != 8:
        axes[0].plot(epochs, data['metrics']['test_acc'], '--', alpha=0.6, label=f'LoRA (r={rank})')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Test Accuracy')
axes[0].set_title('Learning Curves')
axes[0].legend(loc='lower right')
axes[0].grid(True, alpha=0.3)

# Final accuracy vs parameters
ranks = list(rank_results.keys())
params = [rank_results[r]['params'] for r in ranks]
accs = [rank_results[r]['final_test_acc'] for r in ranks]

axes[1].scatter(params, accs, s=100, c='steelblue', zorder=5)
for r, p, a in zip(ranks, params, accs):
    axes[1].annotate(f'r={r}', (p, a), textcoords="offset points", xytext=(0,10), ha='center')

# Add full fine-tuning point
axes[1].scatter([full_trainable], [full_metrics['test_acc'][-1]], s=150, c='red', marker='*', zorder=5, label='Full fine-tune')
axes[1].annotate('Full', (full_trainable, full_metrics['test_acc'][-1]), textcoords="offset points", xytext=(0,10), ha='center')

axes[1].set_xlabel('Trainable Parameters')
axes[1].set_ylabel('Final Test Accuracy')
axes[1].set_title('Accuracy vs Parameter Count')
axes[1].set_xscale('log')
axes[1].grid(True, alpha=0.3)

# Parameter comparison bar chart
methods = ['Full'] + [f'LoRA\nr={r}' for r in ranks]
param_counts = [full_trainable] + params
colors = ['red'] + ['steelblue'] * len(ranks)

bars = axes[2].bar(methods, param_counts, color=colors)
axes[2].set_ylabel('Trainable Parameters')
axes[2].set_title('Parameter Count Comparison')
axes[2].set_yscale('log')

# Add text labels
for bar, count in zip(bars, param_counts):
    axes[2].text(bar.get_x() + bar.get_width()/2, bar.get_height(), 
                 f'{count:,}', ha='center', va='bottom', fontsize=8)

plt.tight_layout()
plt.savefig('lora_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

## Part 5: Merging LoRA Weights

One key advantage of LoRA: you can merge the adapters into the base weights for zero-overhead inference.

In [None]:
def merge_lora_weights(model):
    """
    Merge LoRA weights into the original weights.
    After merging, the model can run without LoRA overhead.
    """
    for name, module in model.named_modules():
        if isinstance(module, LoRALinear):
            # Merge: W' = W + scaling * B @ A
            merged_weight = module.merge_weights()
            module.original.weight.data = merged_weight
            
            # Reset LoRA matrices to zero (effectively disabling them)
            module.lora_A.data.zero_()
            module.lora_B.data.zero_()
    
    return model


# Demonstrate merging
print("Before merging:")
sample_input = torch.randint(0, vocab_size, (2, seq_len)).to(device)

lora_model.train(False)
with torch.no_grad():
    output_before = lora_model(sample_input)

# Check a LoRA layer
for name, module in lora_model.named_modules():
    if isinstance(module, LoRALinear):
        print(f"  LoRA A norm: {module.lora_A.norm().item():.4f}")
        print(f"  LoRA B norm: {module.lora_B.norm().item():.4f}")
        break

# Merge
print("\nMerging LoRA weights...")
merge_lora_weights(lora_model)

print("\nAfter merging:")
with torch.no_grad():
    output_after = lora_model(sample_input)

for name, module in lora_model.named_modules():
    if isinstance(module, LoRALinear):
        print(f"  LoRA A norm: {module.lora_A.norm().item():.4f}")
        print(f"  LoRA B norm: {module.lora_B.norm().item():.4f}")
        break

# Verify outputs match
diff = (output_before - output_after).abs().max().item()
print(f"\nOutput difference after merge: {diff:.2e}")
print("(Should be ~0, meaning merge preserved behavior)")
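Merging is reversible as long as the adapter matrices are kept: subtracting the same `scaling * B @ A` restores the base weight, which is what makes adapter swapping possible. Note that `merge_lora_weights` above zeroes `A` and `B` after merging, so they would need to be saved first. The cell below is a sketch with stand-in tensors (sizes and the scaling value are illustrative), not a change to the functions defined earlier.

```python
import torch

torch.manual_seed(0)
scaling = 2.0                                   # alpha / r, e.g. 16 / 8
W = torch.randn(8, 8, dtype=torch.float64)      # stand-in for a frozen base weight
A = torch.randn(2, 8, dtype=torch.float64)      # rank-2 adapter factors
B = torch.randn(8, 2, dtype=torch.float64)

delta = scaling * (B @ A)
merged = W + delta           # merge: the same expression merge_weights() evaluates
restored = merged - delta    # unmerge: only possible while A and B are still available

restore_err = (restored - W).abs().max().item()
print(f"max error after merge + unmerge: {restore_err:.2e}")
```

In lower precision the round trip is not exact, so keeping an untouched copy of the base weights is the safer option when switching between many adapters.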

## Part 6: LoRA Hyperparameters

Key hyperparameters and their effects:

In [None]:
def visualize_lora_hyperparameters():
    """Visualize the effect of LoRA hyperparameters."""
    
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    
    # Rank vs Parameters
    d = 256  # Hidden dimension
    ranks = [1, 2, 4, 8, 16, 32, 64, 128]
    
    # For Q and V projections in 4 layers
    lora_params = [2 * 4 * (d * r + r * d) for r in ranks]
    full_params = 2 * 4 * d * d  # Just Q and V
    
    axes[0].plot(ranks, lora_params, 'bo-', linewidth=2, markersize=8)
    axes[0].axhline(y=full_params, color='red', linestyle='--', label=f'Full fine-tune Q+V ({full_params:,})')
    axes[0].set_xlabel('LoRA Rank')
    axes[0].set_ylabel('Trainable Parameters')
    axes[0].set_title('Parameters vs LoRA Rank')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    axes[0].set_xscale('log', base=2)
    axes[0].set_yscale('log')
    
    # Alpha/Rank ratio (scaling)
    ranks_demo = [4, 8, 16]
    alphas = [4, 8, 16, 32, 64]
    
    for r in ranks_demo:
        scalings = [a / r for a in alphas]
        axes[1].plot(alphas, scalings, 'o-', label=f'rank={r}')
    
    axes[1].set_xlabel('Alpha')
    axes[1].set_ylabel('Scaling Factor (alpha/rank)')
    axes[1].set_title('LoRA Scaling Factor')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    axes[1].axhline(y=1.0, color='gray', linestyle=':', alpha=0.5)
    
    plt.tight_layout()
    plt.savefig('lora_hyperparameters.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print("LoRA Hyperparameters:")
    print("="*50)
    print("Rank (r):")
    print("  - Higher = more capacity, more parameters")
    print("  - Typical values: 4, 8, 16, 32, 64")
    print("  - Start with 8, increase if underfitting")
    print("")
    print("Alpha (a):")
    print("  - Controls magnitude of LoRA update")
    print("  - Effective scaling = alpha / rank")
    print("  - Common practice: alpha = 2 * rank")
    print("")
    print("Target modules:")
    print("  - Q, V only: minimal parameters")
    print("  - Q, K, V: more capacity")
    print("  - Q, K, V, O: full attention adaptation")
    print("  - + FFN: maximum capacity")

visualize_lora_hyperparameters()
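The "alpha = 2 * rank" convention is easiest to see numerically: tying alpha to the rank keeps the scaling factor constant, whereas a fixed alpha shrinks the update as the rank grows. A minimal sketch in plain Python, where `lora_scaling` is a hypothetical helper mirroring `self.scaling` in `LoRALinear`:

```python
def lora_scaling(alpha: float, rank: int) -> float:
    """Multiplier applied to B @ A, as in LoRALinear.scaling."""
    return alpha / rank

ranks = [2, 4, 8, 16, 32]
tied = {r: lora_scaling(2 * r, r) for r in ranks}    # alpha = 2 * rank
fixed = {r: lora_scaling(16.0, r) for r in ranks}    # alpha held at 16

for r in ranks:
    print(f"rank={r:>2}  tied alpha -> scaling {tied[r]:.2f}   fixed alpha=16 -> scaling {fixed[r]:.2f}")
```

This is why the rank comparison in Part 4 passes `alpha=2*rank`: the runs then differ in capacity but not in update magnitude.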

## Part 7: Summary

Key takeaways about LoRA:

In [None]:
# Final summary
print("""
LoRA: Low-Rank Adaptation
=========================

Core Idea:
  W' = W + (alpha/r) * B @ A
  
  - W: Frozen pretrained weight [d_out, d_in]
  - B: Trainable [d_out, r]
  - A: Trainable [r, d_in]
  - r: Low rank (typically 4-64)

Key Benefits:
  1. Parameter efficiency: ~0.1-1% of full fine-tuning
  2. Memory efficiency: No optimizer states for frozen params
  3. Zero inference overhead: Merge weights after training
  4. Task switching: Swap adapters without reloading base model

When to Use:
  - Memory constrained: Can't fit full fine-tuning
  - Multiple tasks: Need different adaptations of same base
  - Fast iteration: Train adapters quickly
  - Production: Need efficient inference

Common Configuration:
  - Rank: 8 (start here, increase if needed)
  - Alpha: 16 (or 2 * rank)
  - Targets: Q and V projections minimum
  - Learning rate: often 1e-4 to 3e-4, higher than full FT (the small model here used 1e-3)

Historical Note:
  LoRA (Hu et al., 2021) hypothesized that weight updates during
  fine-tuning have low intrinsic rank, and reported large
  parameter reductions with little quality loss.
""")

## Exercises

1. **Different target modules**: Try adding LoRA to K and O projections. Does performance improve?

2. **FFN adaptation**: Add LoRA to the feedforward layers. Compare to attention-only LoRA.

3. **Multi-task LoRA**: Train separate LoRA adapters for two different tasks. Show how to swap them.

4. **Rank analysis**: After training, SVD the merged weight change. Is it actually low-rank?

5. **QLoRA simulation**: Quantize the base weights to int8, keep LoRA in float32. Compare accuracy.
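As a possible starting point for Exercise 5, here is a sketch of per-tensor symmetric int8 quantization (absmax scaling). The function names are hypothetical, and this follows the exercise wording rather than the QLoRA paper, which uses a 4-bit format. In the exercise, the dequantized tensor would replace the frozen base weight while the LoRA matrices stay in float32.

```python
import torch

def quantize_int8(w: torch.Tensor):
    """Map w to int8 using a single absmax scale for the whole tensor."""
    scale = w.abs().max() / 127.0
    q = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)
    return q, scale

def dequantize_int8(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Recover a float32 approximation of the original weight."""
    return q.to(torch.float32) * scale

torch.manual_seed(0)
w = torch.randn(256, 256)            # stand-in for a frozen projection weight
q, scale = quantize_int8(w)
w_hat = dequantize_int8(q, scale)

max_err = (w - w_hat).abs().max().item()
print(f"scale: {scale.item():.5f}, max abs error: {max_err:.5f}")
```

The rounding error per entry is bounded by half the scale, so a single large outlier in `w` degrades precision for every other entry; per-row or per-block scales are a common refinement.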

## Summary

| Concept | Key Point |
|---------|----------|
| Low-rank insight | Fine-tuning changes live in a low-dimensional subspace |
| LoRA formula | W' = W + (alpha/r) * B @ A |
| Parameter savings | Often 99%+ fewer trainable parameters at the model level (about 98% for a single 768x768 layer at r=8) |
| Initialization | A random small, B zeros (start with no change) |
| Merging | Add B @ A to W for zero inference overhead |
| Target modules | Q and V minimum, add more for capacity |

**Key insight:** Full fine-tuning is often wasteful: the task-specific changes tend to occupy a small subspace of the full parameter space. LoRA parameterizes such a subspace with a low-rank factorization, often achieving comparable results with a fraction of the parameters.