# 04 - SLM Training with LoRA

**Previous:** [03_LLM_Evaluation_ZeroShot.ipynb](03_LLM_Evaluation_ZeroShot.ipynb)  
**Next:** [05_SLM_Evaluation_Finetuned.ipynb](05_SLM_Evaluation_Finetuned.ipynb)

---

## What This Notebook Covers

This is the **heart of our project** - training small language models (3B parameters) to become medical diagnosis specialists!

**Key Questions:**
1. What is LoRA and why is it revolutionary?
2. How does LoRA reduce training costs by 95%+?
3. How do we set up and configure LoRA?
4. What happens during the training loop?
5. How do we monitor and interpret training progress?

**Models We'll Train:**
- **Llama 3.2 3B** (Meta's efficient small model)
- **Qwen 2.5 3B** (Alibaba's competitive alternative)

**Why This Matters:**
- Transforms general models into medical specialists
- Tests if specialization can beat size
- Practical for real-world deployment

---

## Setup

In [None]:
import os
import sys
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Critical for GPU memory management
os.environ['PYTORCH_ALLOC_CONF'] = 'expandable_segments:True'

# Add src to path
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root / "src"))

print(f"✅ Project Root: {project_root}")

In [None]:
# Import libraries
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    TaskType
)
from datasets import load_dataset
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict
from tqdm.auto import tqdm
import gc

# Set plotting style
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

print("✅ All libraries imported")

In [None]:
# Check GPU
if torch.cuda.is_available():
    print(f"✅ CUDA Available: {torch.cuda.get_device_name(0)}")
    print(f"   VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
    print(f"   CUDA Version: {torch.version.cuda}")
    device = "cuda"
else:
    print("⚠️  CUDA not available - training will be VERY slow on CPU!")
    device = "cpu"

---

## 1. Understanding LoRA 🎯

### The Traditional Finetuning Problem

**Full Finetuning:**
```
3B parameter model:
  • 3,000,000,000 parameters to update
  • Requires ~12 GB VRAM (model weights)
  • Requires ~24 GB VRAM (optimizer states, gradients)
  • Total: ~36-40 GB VRAM needed
  • Training time: ~8-12 hours on consumer GPU
```
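
The totals above come from simple per-parameter arithmetic. The sketch below is only a rough estimate: `full_finetune_gb` is a hypothetical helper, the bytes-per-parameter values are assumptions (Adam is taken to keep two fp32 states per parameter), and activation memory is ignored entirely.

```python
def full_finetune_gb(n_params: float, weight_bytes: int, grad_bytes: int,
                     optimizer_bytes: int = 8) -> float:
    """Rough VRAM estimate in GB (10^9 bytes), excluding activations."""
    return n_params * (weight_bytes + grad_bytes + optimizer_bytes) / 1e9

n = 3e9  # 3B parameters

# Assumption: bf16 weights and gradients, fp32 Adam states
mixed = full_finetune_gb(n, weight_bytes=2, grad_bytes=2)
# Assumption: everything kept in fp32
fp32 = full_finetune_gb(n, weight_bytes=4, grad_bytes=4)

print(f"bf16 weights/grads + fp32 Adam states: ~{mixed:.0f} GB")
print(f"all fp32:                              ~{fp32:.0f} GB")
```

Depending on the precision chosen, the estimate lands at or above the figure quoted above, before any activations are counted.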

**Problems:**
- ❌ Most GPUs don't have 40 GB VRAM
- ❌ Very slow (must update billions of parameters)
- ❌ Expensive (need powerful hardware)
- ❌ Risk of catastrophic forgetting

### LoRA: Low-Rank Adaptation

**Key Insight:** Most weight updates during finetuning are **low-rank**!

Instead of updating the entire weight matrix `W`, we learn a small update:

```
Traditional Update:
  W_new = W_original + ΔW
  
  Where ΔW is [4096 × 4096] = 16,777,216 parameters ❌

LoRA Update:
  W_new = W_original + A × B
  
  Where:
    A is [4096 × rank] = 4096 × 64 = 262,144 parameters
    B is [rank × 4096] = 64 × 4096 = 262,144 parameters
    Total: 524,288 parameters (97% reduction!) ✅
```
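
The parameter counts above can be checked with a few lines of arithmetic. This is a minimal sketch; `lora_param_count` is a hypothetical helper, not part of any library.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Parameters in the two low-rank factors: [d_out x rank] and [rank x d_in]."""
    return d_out * rank + rank * d_in

d = 4096
full = d * d                               # dense update for one weight matrix
lora = lora_param_count(d, d, rank=64)     # two thin factors instead
reduction = 1 - lora / full

print(f"full update: {full:,}")
print(f"LoRA update: {lora:,}")
print(f"reduction:   {reduction:.1%}")
```

Because the LoRA count grows linearly with the rank, halving the rank halves the number of trainable parameters for that matrix.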

### Visual Explanation

```
Original Weight Matrix W [4096 × 4096]:
┌─────────────────────────┐
│                         │
│     16M parameters      │
│     (frozen ❄️)          │
│                         │
└─────────────────────────┘

LoRA Decomposition:
        Matrix A              Matrix B
        [4096×64]            [64×4096]
┌────┐              ┌─────────────────┐
│    │              │                 │
│ A  │      ×       │       B         │
│    │              │                 │
└────┘              └─────────────────┘
262K params         262K params
(trainable 🔥)      (trainable 🔥)

Final Update:
  W' = W + A × B
```

### The Mathematics

**Forward Pass:**
```
Output = (W + A × B) × Input
       = W × Input + A × B × Input
       = Original_Output + LoRA_Adaptation
```

**Key Parameters:**
- **rank (r)**: Dimension of low-rank matrices (typically 8, 16, 32, or 64)
  - Lower rank: Fewer parameters, faster training, may limit capacity
  - Higher rank: More parameters, slower training, more expressive
  
- **alpha (α)**: Scaling factor for LoRA updates
  - Update is scaled by α/r
  - Typically set to rank or 2×rank
  
- **target_modules**: Which layers to apply LoRA to
  - Usually attention layers (q_proj, k_proj, v_proj, o_proj)
  - Can include MLP layers (gate_proj, up_proj, down_proj)
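
The forward pass and the α/r scaling described above can be sketched with plain Python lists. This is a toy illustration using this notebook's notation (A is `[d × r]`, B is `[r × d]`); `matvec` and `lora_forward` are hypothetical helpers and the numbers are made up.

```python
def matvec(M, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m_ij * x_j for m_ij, x_j in zip(row, x)) for row in M]

def lora_forward(W, A, B, x, alpha, r):
    """Compute W x + (alpha / r) * A (B x) without ever forming A x B."""
    base = matvec(W, x)                   # frozen path
    update = matvec(A, matvec(B, x))      # B maps d -> r, then A maps r -> d
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]

# Toy sizes: d = 3, r = 1
W = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
A = [[1.0], [2.0], [3.0]]        # [d x r]
B = [[0.5, 0.0, -0.5]]           # [r x d]
x = [2.0, 4.0, 6.0]

y = lora_forward(W, A, B, x, alpha=2, r=1)
print(y)
```

Applying the two thin matrices one after the other keeps the cost proportional to the rank, which is why the adapter path adds little compute on top of the frozen layer.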

### LoRA Benefits

**Memory:**
```
Full Finetuning:  40 GB VRAM
LoRA:             6-8 GB VRAM  (5-7x reduction!)
```

**Speed:**
```
Full Finetuning:  12 hours
LoRA:             2-3 hours    (4-6x faster!)
```

**Storage:**
```
Full Model:       12 GB
LoRA Adapters:    100-200 MB   (60-120x smaller!)
```

---

## 2. Loading the Dataset

First, let's load and split our medical conversation dataset:

In [None]:
# Load dataset
print("Loading MedSynth dataset...")
dataset = load_dataset("samhog/medsynth-diagnosis-icd10-10k", split="train")

# Split into train/val/test (70/15/15)
train_test_split = dataset.train_test_split(test_size=0.3, seed=42)
val_test_split = train_test_split['test'].train_test_split(test_size=0.5, seed=42)

train_dataset = train_test_split['train']
val_dataset = val_test_split['train']
test_dataset = val_test_split['test']

print(f"\n✅ Dataset Split:")
print(f"   Train: {len(train_dataset):,} examples")
print(f"   Val:   {len(val_dataset):,} examples")
print(f"   Test:  {len(test_dataset):,} examples")

# Show example
example = train_dataset[0]
print(f"\nExample Training Case:")
print(f"  Diagnosis: {example['diagnosis']}")
print(f"  Messages: {len(example['messages'])} turns")
for msg in example['messages'][:2]:
    print(f"    {msg['role']:8s}: {msg['content'][:50]}...")

---

## 3. Loading the Base Model

We'll load Llama 3.2 3B with 4-bit quantization to save memory:

In [None]:
# Model to finetune
model_name = "meta-llama/Llama-3.2-3B-Instruct"

print(f"Loading base model: {model_name}")
print("This may take 1-2 minutes...\n")

# Quantization config (same as LLM evaluation)
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"  # Important for causal LM training
print(f"✅ Tokenizer loaded")

# Load base model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
)
print(f"✅ Base model loaded and quantized")

# Check memory
if torch.cuda.is_available():
    allocated = torch.cuda.memory_allocated() / 1024**3
    print(f"\nGPU Memory: {allocated:.2f} GB")

### Prepare Model for LoRA Training

4-bit models need special preparation before training:

In [None]:
# Prepare model for k-bit training
model = prepare_model_for_kbit_training(model)

print("✅ Model prepared for k-bit training")
print("\n   What this does:")
print("   • Enables gradient checkpointing (saves memory)")
print("   • Casts layernorm to float32 (stability)")
print("   • Enables input embedding gradients")

---

## 4. Configuring LoRA

Now the crucial step - configuring LoRA parameters:

In [None]:
# LoRA Configuration
lora_config = LoraConfig(
    r=64,                          # Rank of LoRA matrices (higher = more capacity)
    lora_alpha=128,                # Scaling factor (typically 2×rank)
    target_modules=[               # Which layers to apply LoRA to
        "q_proj",                  # Query projection (attention)
        "k_proj",                  # Key projection (attention)
        "v_proj",                  # Value projection (attention)
        "o_proj",                  # Output projection (attention)
        "gate_proj",               # Gate projection (MLP)
        "up_proj",                 # Up projection (MLP)
        "down_proj"                # Down projection (MLP)
    ],
    lora_dropout=0.05,             # Dropout for regularization
    bias="none",                   # Don't train bias terms
    task_type=TaskType.CAUSAL_LM   # Task type (causal language modeling)
)

print("LoRA Configuration:")
print(f"  Rank (r):          {lora_config.r}")
print(f"  Alpha (α):         {lora_config.lora_alpha}")
print(f"  Scaling (α/r):     {lora_config.lora_alpha / lora_config.r}")
print(f"  Target Modules:    {len(lora_config.target_modules)} types")
print(f"  Dropout:           {lora_config.lora_dropout}")
print(f"\n  Expected trainable params: ~3% of total")
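
The trainable-parameter count for this configuration can be estimated before applying it. The sketch below assumes the published Llama 3.2 3B dimensions (hidden size 3072, key/value projection width 1024, MLP width 8192, 28 layers); these values are assumptions here and are worth checking against `model.config`.

```python
RANK = 64
NUM_LAYERS = 28                            # assumed for Llama 3.2 3B
HIDDEN, KV_DIM, MLP = 3072, 1024, 8192     # assumed projection widths

# (input width, output width) of each targeted linear layer
SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, MLP),
    "up_proj": (HIDDEN, MLP),
    "down_proj": (MLP, HIDDEN),
}

# Each adapted layer adds rank * (d_in + d_out) parameters
per_layer = sum(RANK * (d_in + d_out) for d_in, d_out in SHAPES.values())
trainable = per_layer * NUM_LAYERS

print(f"LoRA parameters per layer: {per_layer:,}")
print(f"LoRA parameters in total:  {trainable:,}")
print(f"Adapter size at 4 bytes/param: ~{trainable * 4 / 1024**2:.0f} MB")
```

Against roughly 3.2 billion base parameters, that is on the order of 3% trainable. A manual `numel()` sum over a 4-bit model may report a different ratio, since quantized layers can store their weights in packed form.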

### Apply LoRA to Model

In [None]:
# Apply LoRA configuration to model
model = get_peft_model(model, lora_config)

print("✅ LoRA applied to model\n")

# Print trainable parameters
model.print_trainable_parameters()

# Detailed breakdown
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())

print(f"\nDetailed Breakdown:")
print(f"  Total parameters:     {total:,}")
print(f"  Trainable parameters: {trainable:,}")
print(f"  Percentage trainable: {trainable/total*100:.2f}%")
print(f"\n  Memory saved: ~{(total - trainable) / total * 100:.0f}% reduction in optimizer memory!")

### Understanding the Architecture

Let's inspect which modules got LoRA adapters:

In [None]:
# Count LoRA modules
lora_modules = [name for name, module in model.named_modules() if 'lora' in name.lower()]

print(f"LoRA Modules Added: {len(lora_modules)}\n")

# Show first 10 as examples
print("Example LoRA modules:")
for name in lora_modules[:10]:
    print(f"  • {name}")

if len(lora_modules) > 10:
    print(f"  ... and {len(lora_modules) - 10} more")

---

## 5. Data Preprocessing

Format our conversations for training:

In [None]:
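def format_for_training(example: Dict) -> Dict:
    """
    Format a conversation example for training.

    Returns token IDs only; the data collator builds the labels.
    """
    system_prompt = (
        "You are a medical diagnosis assistant. "
        "Based on the doctor-patient conversation, predict the ICD-10 diagnosis code."
    )
    
    # Format conversation
    conversation_text = "\n".join([
        f"{msg['role'].capitalize()}: {msg['content']}"
        for msg in example['messages']
    ])
    
    # Build chat
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": conversation_text},
        {"role": "assistant", "content": example['diagnosis']}
    ]
    
    # Apply chat template and tokenize.
    # Note: by default truncation cuts from the end, so a conversation longer
    # than max_length loses the assistant turn (the diagnosis). Raise
    # max_length if many examples hit the limit.
    encoded = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        max_length=512,
        truncation=True,
        return_dict=True
    )
    
    # Labels are omitted on purpose: DataCollatorForLanguageModeling(mlm=False)
    # pads each batch and creates labels from input_ids. Unpadded labels of
    # different lengths returned here can fail to batch.
    return {
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"]
    }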

print("Preprocessing datasets...\n")

# Process datasets
train_formatted = [format_for_training(ex) for ex in tqdm(train_dataset, desc="Train")]
val_formatted = [format_for_training(ex) for ex in tqdm(val_dataset, desc="Val")]

print(f"\n✅ Preprocessing complete")
print(f"   Train: {len(train_formatted)} examples")
print(f"   Val:   {len(val_formatted)} examples")
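
The preprocessing above trains on every token, including the system prompt and the conversation. A common alternative, which this notebook does not use, is to compute the loss only on the assistant's answer by setting the prompt positions in the labels to `-100`, the index that PyTorch's cross-entropy loss ignores by default. The sketch below shows the idea on made-up token IDs; `mask_prompt_labels` is a hypothetical helper.

```python
from typing import List

IGNORE_INDEX = -100  # positions with this label are skipped by the loss

def mask_prompt_labels(input_ids: List[int], prompt_len: int) -> List[int]:
    """Copy input_ids to labels, ignoring the first prompt_len positions."""
    if not 0 <= prompt_len <= len(input_ids):
        raise ValueError("prompt_len must be between 0 and len(input_ids)")
    return [IGNORE_INDEX] * prompt_len + list(input_ids[prompt_len:])

# Made-up token IDs: five prompt tokens followed by a three-token answer
input_ids = [101, 7, 8, 9, 10, 501, 502, 2]
labels = mask_prompt_labels(input_ids, prompt_len=5)
print(labels)
```

Using this in practice would also need a collator that pads the provided labels, because `DataCollatorForLanguageModeling` rebuilds labels from `input_ids`.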

### Create PyTorch Datasets

In [None]:
from torch.utils.data import Dataset

class MedicalDataset(Dataset):
    """Simple dataset wrapper for formatted examples."""
    
    def __init__(self, formatted_examples: List[Dict]):
        self.examples = formatted_examples
    
    def __len__(self):
        return len(self.examples)
    
    def __getitem__(self, idx):
        return self.examples[idx]

train_torch_dataset = MedicalDataset(train_formatted)
val_torch_dataset = MedicalDataset(val_formatted)

print(f"✅ PyTorch datasets created")
print(f"   Train size: {len(train_torch_dataset)}")
print(f"   Val size:   {len(val_torch_dataset)}")

---

## 6. Training Configuration

Set up training hyperparameters:

In [None]:
# Output directory for checkpoints
output_dir = project_root / "models" / "llama-3.2-3b-medical-lora"
output_dir.mkdir(parents=True, exist_ok=True)

# Training arguments
training_args = TrainingArguments(
    # Output
    output_dir=str(output_dir),
    
    # Training schedule
    num_train_epochs=3,                    # Number of passes through dataset
    per_device_train_batch_size=8,        # Batch size per GPU
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=4,         # Effective batch size = 8 × 4 = 32
    
    # Learning rate
    learning_rate=2e-4,                    # LoRA typically uses higher LR than full finetuning
    lr_scheduler_type="cosine",            # Cosine decay schedule
    warmup_ratio=0.03,                     # 3% warmup
    
    # Optimization
    optim="adamw_torch",                   # AdamW optimizer
    weight_decay=0.01,                     # Decoupled weight decay (the "W" in AdamW)
    max_grad_norm=1.0,                     # Gradient clipping
    
    # Precision
    bf16=True,                             # Use bfloat16 (faster on modern GPUs)
    
    # Evaluation
    eval_strategy="steps",                 # Evaluate periodically (older transformers releases call this evaluation_strategy)
    eval_steps=100,                        # Evaluate every 100 steps
    
    # Saving
    save_strategy="steps",
    save_steps=200,                        # Save checkpoint every 200 steps
    save_total_limit=3,                    # Keep at most 3 checkpoints on disk (oldest are deleted first)
    load_best_model_at_end=True,           # Load best checkpoint at end
    
    # Logging
    logging_steps=10,                      # Log every 10 steps
    logging_dir=str(output_dir / "logs"),
    report_to=[],                          # Disable wandb/tensorboard for demo
    
    # Memory optimization
    gradient_checkpointing=True,           # Trade compute for memory
    
    # Misc
    seed=42,
    remove_unused_columns=False,
)

print("Training Configuration:")
print(f"\nSchedule:")
print(f"  Epochs:                 {training_args.num_train_epochs}")
print(f"  Steps per epoch:        ~{len(train_torch_dataset) // (training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps)}")
print(f"  Total training steps:   ~{len(train_torch_dataset) // (training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps) * training_args.num_train_epochs}")

print(f"\nBatch Size:")
print(f"  Per device:             {training_args.per_device_train_batch_size}")
print(f"  Gradient accumulation:  {training_args.gradient_accumulation_steps}")
print(f"  Effective batch size:   {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")

print(f"\nLearning Rate:")
print(f"  Initial LR:             {training_args.learning_rate}")
print(f"  Scheduler:              {training_args.lr_scheduler_type}")
print(f"  Warmup:                 {training_args.warmup_ratio*100:.0f}% of steps")

print(f"\nOutput:")
print(f"  Checkpoint dir:         {output_dir}")
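
The schedule numbers printed above follow from a little arithmetic. The sketch below is an approximation for a single-GPU run; `estimate_schedule` is a hypothetical helper, the 7,000-example training set is an assumption (70% of a 10k dataset), and the Trainer's own rounding may differ by a step.

```python
import math

def estimate_schedule(n_examples, per_device_batch, grad_accum, epochs, warmup_ratio):
    """Approximate optimizer steps for a single-GPU run."""
    effective_batch = per_device_batch * grad_accum
    steps_per_epoch = math.ceil(n_examples / effective_batch)
    total_steps = steps_per_epoch * epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return effective_batch, steps_per_epoch, total_steps, warmup_steps

schedule = estimate_schedule(7000, per_device_batch=8, grad_accum=4,
                             epochs=3, warmup_ratio=0.03)
print(schedule)
```

With these inputs the run is only a few hundred optimizer steps long, so evaluating every 100 steps yields a handful of validation points per epoch.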

---

## 7. Training Loop

Now let's train! This uses HuggingFace Trainer for convenience:

In [None]:
# Data collator for dynamic padding
from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False  # Not masked language modeling (we're doing causal LM)
)

# Create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_torch_dataset,
    eval_dataset=val_torch_dataset,
    data_collator=data_collator,
)

print("✅ Trainer created")
print(f"\n{'='*70}")
print("Starting Training...")
print(f"{'='*70}\n")

In [None]:
# Start training!
# This will take 2-3 hours depending on GPU
train_result = trainer.train()

print(f"\n{'='*70}")
print("✅ Training Complete!")
print(f"{'='*70}")

# Print training summary
print(f"\nTraining Summary:")
print(f"  Total time:           {train_result.metrics['train_runtime']:.0f} seconds")
print(f"  Samples per second:   {train_result.metrics['train_samples_per_second']:.2f}")
print(f"  Average train loss:   {train_result.metrics['train_loss']:.4f}")

### Understanding Training Metrics

**Loss:** How "wrong" the model's predictions are
- Lower is better
- Should decrease over training
- Typical range: 0.5-2.0 for well-trained models

**Learning Rate Schedule:**
```
LR
│     
│   ╱──╲
│  ╱    ╲___
│ ╱         ╲___
│╱              ╲____
└─────────────────────→ Steps
  Warmup    Cosine Decay
```
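
The shape in the diagram can be written as a small function. This is a sketch of linear warmup followed by half-cosine decay to zero; `cosine_lr` is a hypothetical helper, and the step counts are illustrative rather than taken from an actual run.

```python
import math

def cosine_lr(step: int, total_steps: int, warmup_steps: int, peak_lr: float) -> float:
    """Linear warmup to peak_lr, then half-cosine decay to zero."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

# Illustrative step counts; 2e-4 is the peak learning rate configured earlier
for step in (0, 10, 20, 338, 657):
    lr = cosine_lr(step, total_steps=657, warmup_steps=20, peak_lr=2e-4)
    print(f"step {step:4d}: lr = {lr:.2e}")
```

The rate starts at zero, reaches its peak when warmup ends, and returns to zero on the last step.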

**Gradient Accumulation:**
- Batch size 8, accumulate 4 steps
- Effective batch size: 32
- Updates every 4 forward passes
- Allows larger effective batch size with limited memory
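
Why accumulation gives the same update as a larger batch can be seen on a toy problem. The sketch below uses a one-parameter model `y = w * x` with made-up data; `grad_mse` is a hypothetical helper. It assumes equal-sized micro-batches and a mean-reduced loss, which is when the two gradients match exactly.

```python
def grad_mse(w: float, batch) -> float:
    """Gradient of mean squared error for the model y = w * x."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)

# Made-up data: 8 examples split into 4 micro-batches of 2
data = [(1.0, 2.0), (2.0, 4.1), (3.0, 5.9), (4.0, 8.2),
        (5.0, 9.8), (6.0, 12.1), (7.0, 14.0), (8.0, 15.9)]
micro_batches = [data[i:i + 2] for i in range(0, len(data), 2)]
w = 0.5

# Gradient computed on all 8 examples at once
full_grad = grad_mse(w, data)

# Accumulate: each micro-batch gradient is divided by the number of accumulation steps
accumulated = 0.0
for mb in micro_batches:
    accumulated += grad_mse(w, mb) / len(micro_batches)

print(f"full batch:  {full_grad:.6f}")
print(f"accumulated: {accumulated:.6f}")
```

Only one micro-batch of activations is held in memory at a time, which is where the memory saving comes from.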

---

## 8. Visualizing Training Progress

Let's plot the training metrics:

In [None]:
# Load training logs
import json

log_history = trainer.state.log_history

# Extract metrics
train_losses = []
eval_losses = []
learning_rates = []
steps = []

for entry in log_history:
    if 'loss' in entry:  # Training step
        steps.append(entry['step'])
        train_losses.append(entry['loss'])
        learning_rates.append(entry.get('learning_rate', None))
    if 'eval_loss' in entry:  # Evaluation step
        eval_losses.append((entry['step'], entry['eval_loss']))

print(f"Training logs: {len(train_losses)} training steps, {len(eval_losses)} evaluation steps")

In [None]:
# Plot training curves
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))

# Training loss
ax1.plot(steps, train_losses, label='Training Loss', color='#3498db', linewidth=2)
if eval_losses:
    eval_steps, eval_vals = zip(*eval_losses)
    ax1.plot(eval_steps, eval_vals, label='Validation Loss', color='#e74c3c', linewidth=2, marker='o')
ax1.set_xlabel('Training Steps')
ax1.set_ylabel('Loss')
ax1.set_title('Training Progress: Loss Over Time')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Learning rate
lr_steps = [s for s, lr in zip(steps, learning_rates) if lr is not None]
lr_vals = [lr for lr in learning_rates if lr is not None]
ax2.plot(lr_steps, lr_vals, label='Learning Rate', color='#2ecc71', linewidth=2)
ax2.set_xlabel('Training Steps')
ax2.set_ylabel('Learning Rate')
ax2.set_title('Learning Rate Schedule')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\n💡 What to look for:")
print("   • Loss should decrease steadily")
print("   • Validation loss should track training loss")
print("   • LR should warm up then decay")
print("   • If validation >> training, model is overfitting")
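
Training loss logged every few steps is often noisy, which can hide the trend. One optional way to make the curve easier to read is an exponential moving average; the sketch below is not part of the notebook's pipeline, `ema_smooth` is a hypothetical helper, and the loss values are made up.

```python
from typing import List

def ema_smooth(values: List[float], beta: float = 0.9) -> List[float]:
    """Exponential moving average with bias correction."""
    smoothed, avg = [], 0.0
    for i, v in enumerate(values, start=1):
        avg = beta * avg + (1 - beta) * v
        smoothed.append(avg / (1 - beta ** i))   # correct the start-up bias toward 0
    return smoothed

# Made-up noisy losses for illustration
losses = [2.0, 1.6, 1.9, 1.2, 1.4, 0.9, 1.1, 0.8]
smoothed = ema_smooth(losses)
print([round(s, 3) for s in smoothed])
```

Passing `train_losses` through such a function before plotting would give a smoother line while leaving the raw values untouched.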

---

## 9. Saving the Model

Save the LoRA adapters (not the full model!):

In [None]:
# Save final model
final_model_dir = output_dir / "final_model"

# Save LoRA adapters
model.save_pretrained(final_model_dir)
tokenizer.save_pretrained(final_model_dir)

print(f"✅ Model saved to: {final_model_dir}")

# Check file sizes
import os

total_size = 0
for root, dirs, files in os.walk(final_model_dir):
    for file in files:
        filepath = os.path.join(root, file)
        total_size += os.path.getsize(filepath)

print(f"\nModel Files:")
for file in os.listdir(final_model_dir):
    filepath = final_model_dir / file
    if filepath.is_file():
        size = filepath.stat().st_size / 1024**2  # MB
        print(f"  • {file:30s} {size:8.1f} MB")

print(f"\n  Total size: {total_size / 1024**2:.1f} MB")
print(f"\n  Compare to full model: ~12,000 MB")
print(f"  Space saved: {(1 - total_size / (12000 * 1024**2)) * 100:.1f}%")

### What's Saved?

**LoRA adapters only:**
- `adapter_model.safetensors` - LoRA weights (A and B matrices)
- `adapter_config.json` - LoRA configuration
- Tokenizer files

**NOT saved:**
- Base model weights (these stay frozen)
- You load base model + adapters later

**Loading later:**
```python
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Load base model
base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")

# Load LoRA adapters
model = PeftModel.from_pretrained(base_model, "path/to/adapters")
```

---

## 10. Quick Test

Let's test the finetuned model on a validation example:

In [None]:
# Test on validation example
test_example = val_dataset[0]

# Format for inference
system_prompt = (
    "You are a medical diagnosis assistant. "
    "Based on the doctor-patient conversation below, predict ONLY the ICD-10 diagnosis code. "
    "Respond with just the code (e.g., 'J06.9'), nothing else."
)

conversation_text = "\n".join([
    f"{msg['role'].capitalize()}: {msg['content']}"
    for msg in test_example['messages']
])

messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": conversation_text}
]

formatted = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)

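# Tokenize and generate
# The chat template already adds the BOS token, so don't add special tokens again
inputs = tokenizer(formatted, return_tensors="pt", add_special_tokens=False).to(model.device)

model.eval()  # switch off LoRA dropout for inference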

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=10,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id
    )

# Decode
generated = outputs[0][inputs['input_ids'].shape[1]:]
prediction = tokenizer.decode(generated, skip_special_tokens=True).strip()

print("Test Prediction:")
print(f"\nConversation snippet:")
for msg in test_example['messages'][:2]:
    print(f"  {msg['role']:8s}: {msg['content'][:60]}...")

print(f"\nGround Truth: {test_example['diagnosis']}")
print(f"Prediction:   {prediction}")
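# Guard against an empty generation before taking the first token
predicted_code = prediction.split()[0] if prediction.split() else ""
print(f"\nMatch: {'✅ CORRECT!' if predicted_code == test_example['diagnosis'] else '❌ Incorrect'}")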
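
Taking the first whitespace-separated token works when the model answers with just the code, but it breaks if the output carries extra words or punctuation. A slightly more tolerant option is to search for something shaped like an ICD-10 code. The sketch below uses a simplified pattern (letter, digit, alphanumeric, optional dotted suffix) that is an assumption, not a full ICD-10 validator; `extract_icd10` is a hypothetical helper.

```python
import re
from typing import Optional

# Simplified ICD-10 shape: letter, digit, alphanumeric, optional ".xxxx" suffix
ICD10_PATTERN = re.compile(r"\b[A-Z]\d[0-9A-Z](?:\.[0-9A-Z]{1,4})?\b")

def extract_icd10(text: str) -> Optional[str]:
    """Return the first ICD-10-looking code in text, or None if there is none."""
    match = ICD10_PATTERN.search(text.upper())
    return match.group(0) if match else None

print(extract_icd10("J06.9"))
print(extract_icd10("The code is j06.9."))
print(extract_icd10(""))
```

Whether this helps depends on how the `diagnosis` field is formatted in the dataset, so it is worth comparing a few extracted values against the ground truth first.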

---

## 11. Key Takeaways 💡

### What We Learned

1. **LoRA is Revolutionary**
   - 97% fewer trainable parameters
   - 5-7x less VRAM needed
   - 4-6x faster training
   - 99%+ smaller checkpoint files

2. **How LoRA Works**
   - Low-rank decomposition: `W' = W + A × B`
   - Freeze original weights, train small adapters
   - Apply to attention and MLP layers

3. **Training Process**
   - Load base model with quantization
   - Apply LoRA configuration
   - Format data with chat templates
   - Train with HuggingFace Trainer
   - Monitor loss and learning rate
   - Save adapters (not full model)

4. **Key Hyperparameters**
   - **rank (r)**: Controls adapter capacity (64 is good default)
   - **alpha**: Scaling factor (typically 2×rank)
   - **learning_rate**: Higher for LoRA (2e-4) than full finetuning (5e-5)
   - **batch_size**: Effective size matters (use gradient accumulation)

### Training Best Practices

✅ **Do:**
- Use 4-bit quantization to save memory
- Apply LoRA to attention + MLP layers
- Use cosine LR schedule with warmup
- Monitor both training and validation loss
- Save checkpoints regularly

❌ **Don't:**
- Set rank too low (<16) or too high (>128)
- Use tiny batch sizes without accumulation
- Skip gradient clipping
- Ignore validation loss (overfitting!)
- Train for too many epochs (3-5 usually enough)

### Expected Results

**Training Time:**
- RTX 5090 (32GB): ~2-3 hours
- RTX 4090 (24GB): ~3-4 hours  
- RTX 3090 (24GB): ~4-6 hours

**Final Loss:**
- Training loss: ~0.5-1.0
- Validation loss: ~0.6-1.2
- If validation >> training: reduce epochs or add dropout

---

## 12. Memory Cleanup

Free GPU memory before moving to next notebook:

In [None]:
# Clean up
del model
del tokenizer
del trainer

gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    torch.cuda.ipc_collect()

print("‚úÖ Memory freed")

---

## 13. What's Next? 👉

We've successfully finetuned a 3B model! Now:

1. **Evaluate Finetuned Model** - How much did performance improve?
   - Compare with zero-shot LLM baseline
   - Measure exact accuracy, F1, etc.
   - Analyze error patterns

2. **Compare Results** - Does specialization beat size?
   - Finetuned 3B vs zero-shot 8B
   - Performance vs Speed vs Memory
   - Visualize trade-offs

3. **Test on Custom Cases** - Real-world testing
   - Your own medical conversations
   - Interactive comparison

**Next Notebook:** [05_SLM_Evaluation_Finetuned.ipynb](05_SLM_Evaluation_Finetuned.ipynb)

---

## Summary

In this notebook, we:

- ✅ Understood LoRA and low-rank adaptation
- ✅ Configured LoRA with optimal hyperparameters
- ✅ Prepared model for k-bit training
- ✅ Preprocessed medical conversation data
- ✅ Trained with HuggingFace Trainer
- ✅ Monitored training progress
- ✅ Saved LoRA adapters
- ✅ Tested finetuned model

**Key Files in Project:**
- `src/training/trainer.py` - Training logic and LoRA setup
- `src/config/base_config.py` - LoRA and training hyperparameters
- `models/*/` - Saved LoRA adapters and checkpoints

---

**Continue to:** [05_SLM_Evaluation_Finetuned.ipynb](05_SLM_Evaluation_Finetuned.ipynb) 🚀