# LoRA Fine-Tuning EmbeddingGemma on PDF QA Pairs

This notebook fine-tunes EmbeddingGemma with LoRA on query-passage pairs generated from PDF documents. Fine-tuning aims to:
- **Map queries to relevant passages**: Pull the embeddings of each query and its correct passage closer together
- **Improve retrieval**: Enable better semantic search over PDF content
- **Adapt to your domain**: Specialize the general-purpose EmbeddingGemma model for the content of your PDFs

This workflow uses the training pairs generated in notebook 10.


In [None]:
# Import all necessary functions from src modules
from src.data.loaders import load_query_passage_pairs, validate_pairs
from src.models.embedding_pipeline import load_embeddinggemma_model
from src.models.lora_setup import setup_lora_model, print_trainable_parameters
from src.training.trainer import train_model
from src.utils.paths import timestamped_path
import torch
import pandas as pd


## Step 1: Load Training Dataset

Load the query-passage pairs generated in notebook 10. The `load_query_passage_pairs()` function automatically maps 'query' → 'anchor' and 'passage' → 'positive' for compatibility with the training framework.
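
As a rough illustration of the column mapping described above, the sketch below renames `query` to `anchor` and `passage` to `positive` using only the standard library. The helper `rows_to_pairs` and the sample CSV text are hypothetical; the actual `load_query_passage_pairs()` in `src.data.loaders` may read, clean, or validate the data differently.

```python
import csv
import io


def rows_to_pairs(csv_text):
    """Hypothetical helper: map CSV columns 'query'/'passage' to 'anchor'/'positive'."""
    reader = csv.DictReader(io.StringIO(csv_text))
    pairs = []
    for row in reader:
        pairs.append({
            "anchor": row["query"].strip(),
            "positive": row["passage"].strip(),
        })
    return pairs


# Made-up one-row CSV, used only to show the output shape.
sample_csv = "query,passage\nWhat is LoRA?,LoRA adds low-rank adapter matrices.\n"
pairs = rows_to_pairs(sample_csv)
print(pairs[0]["anchor"])
print(pairs[0]["positive"])
```

Each element is a dictionary with `anchor` and `positive` keys, which matches how `train_data[0]['anchor']` and `train_data[0]['positive']` are accessed in the next cell.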


In [None]:
# Load training dataset from CSV file generated in notebook 10
# Replace with the actual path to your train CSV file
train_csv_path = "data/processed/pdf_query_passage_pairs_train_YYYYMMDD_HHMMSS.csv"  # Update this path

# Load pairs (automatically maps query→anchor, passage→positive)
train_data = load_query_passage_pairs(train_csv_path)

print(f"Loaded {len(train_data)} training pairs")
print("\nFirst pair:")
print(f"  Query (anchor):   '{train_data[0]['anchor']}'")
print(f"  Passage (positive): '{train_data[0]['positive'][:100]}...'")

# Validate dataset
stats = validate_pairs(train_data)
print("\nDataset statistics:")
print(f"  Average query length: {stats['avg_anchor_length']:.1f} characters")
print(f"  Average passage length: {stats['avg_positive_length']:.1f} characters")
print(f"  Has empty strings: {stats['has_empty']}")


## Step 2: Load Base Model and Apply LoRA

We'll load the base EmbeddingGemma model and configure it with LoRA adapters. LoRA keeps the base weights frozen and trains only a small set of added low-rank adapter weights, which makes training efficient and reduces the risk of catastrophic forgetting.

### LoRA Configuration

- `r`: Rank of LoRA adapters (default: 16, controls capacity)
- `lora_alpha`: Scaling factor (default: 32, typically 2x rank)
- `lora_dropout`: Dropout rate for LoRA layers (default: 0.1)
- `target_modules`: Names of the modules that receive LoRA adapters (typically the attention projections)
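
To make `r` and `lora_alpha` concrete, here is a small arithmetic sketch. It assumes standard LoRA, where a frozen weight of shape `d_out x d_in` gets an update `B @ A` with `A` of shape `r x d_in` and `B` of shape `d_out x r`, scaled by `lora_alpha / r`. The 768 x 768 projection size is a hypothetical example; the real module shapes in EmbeddingGemma may differ, and `print_trainable_parameters()` reports the actual counts.

```python
def lora_param_count(d_in, d_out, r):
    """Trainable parameters added by one adapter: A is (r x d_in), B is (d_out x r)."""
    return r * d_in + d_out * r


def lora_scaling(lora_alpha, r):
    """Factor applied to the low-rank update before it is added to the frozen weight."""
    return lora_alpha / r


# Hypothetical square projection; not read from the real model config.
d_in = d_out = 768
full_params = d_in * d_out
adapter_params = lora_param_count(d_in, d_out, r=16)

print(f"Frozen weight parameters: {full_params}")        # 589824
print(f"LoRA adapter parameters:  {adapter_params}")     # 24576
print(f"Adapter / frozen ratio:   {adapter_params / full_params:.4f}")  # 0.0417
print(f"Scaling (alpha / r):      {lora_scaling(32, 16)}")  # 2.0
```

Doubling `r` doubles the adapter size. If `lora_alpha` stays fixed, it also halves the scaling factor, which is one reason the two values are often changed together.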


In [None]:
# Load base EmbeddingGemma model
tokenizer, base_model = load_embeddinggemma_model()

# Check device
device = next(base_model.parameters()).device
print(f"Model loaded on device: {device}")

# Apply LoRA configuration
model = setup_lora_model(
    base_model,
    r=16,  # LoRA rank (controls adapter capacity)
    lora_alpha=32,  # Scaling factor (typically 2x rank)
    lora_dropout=0.1,  # Dropout for regularization
    target_modules=["q_proj", "k_proj", "v_proj"]  # Apply to attention projections
)

# Verify only LoRA parameters are trainable
print_trainable_parameters(model)


## Step 3: Load Validation Dataset (Optional)

If you have a validation set, load it to monitor training progress and prevent overfitting.
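
If no validation file exists yet, one option is to hold out a fraction of the training pairs. The sketch below uses only the standard library; `split_pairs` and its default values are hypothetical and not part of the `src` modules. Note that a random split can leak information if several queries share the same passage, so splitting by source passage or document may be preferable.

```python
import random


def split_pairs(pairs, val_fraction=0.1, seed=42):
    """Hypothetical helper: shuffle a copy of the pairs and hold out a validation slice."""
    shuffled = list(pairs)
    random.Random(seed).shuffle(shuffled)  # seeded so the split is reproducible
    n_val = max(1, int(len(shuffled) * val_fraction))
    return shuffled[n_val:], shuffled[:n_val]


# Made-up pairs, used only to show the resulting sizes.
example_pairs = [{"anchor": f"q{i}", "positive": f"p{i}"} for i in range(20)]
train_split, val_split = split_pairs(example_pairs)
print(len(train_split), len(val_split))  # 18 2
```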


In [None]:
# Load validation dataset (optional)
# Replace with the actual path to your val CSV file
val_csv_path = "data/processed/pdf_query_passage_pairs_val_YYYYMMDD_HHMMSS.csv"  # Update this path

try:
    val_data = load_query_passage_pairs(val_csv_path)
    print(f"Loaded {len(val_data)} validation pairs")
except FileNotFoundError:
    print("Validation file not found, skipping validation")
    val_data = None


## Step 4: Train the Model

Now we'll train the model using contrastive learning. The training process:
1. **Forward pass**: Compute embeddings for queries and passages
2. **Contrastive loss**: Use Multiple Negatives Ranking Loss to pull each query toward its positive passage and away from the other passages in the batch
3. **Backward pass**: Update only LoRA parameters
4. **Validation**: Monitor performance on validation set (if available)
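
The contrastive step can be illustrated with a dependency-free sketch of Multiple Negatives Ranking Loss. For each query, the passage at the same batch index is treated as the correct class, every other passage in the batch acts as a negative, and the loss is cross-entropy over scaled cosine similarities. The function names, the toy 2-d vectors, and `scale=20.0` are assumptions for illustration; the loss inside `train_model` may use different settings and operates on tensors.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def mnr_loss(query_embs, passage_embs, scale=20.0):
    """Mean cross-entropy where passage i is the target for query i."""
    losses = []
    for i, q in enumerate(query_embs):
        logits = [scale * cosine(q, p) for p in passage_embs]
        # Numerically stable log-sum-exp over all passages in the batch.
        m = max(logits)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
        losses.append(log_sum_exp - logits[i])
    return sum(losses) / len(losses)


# Toy embeddings: 'aligned' puts each correct passage near its query, 'swapped' does not.
queries = [[1.0, 0.0], [0.0, 1.0]]
aligned = [[0.9, 0.1], [0.1, 0.9]]
swapped = [[0.1, 0.9], [0.9, 0.1]]

print(f"Loss with aligned pairs: {mnr_loss(queries, aligned):.6f}")
print(f"Loss with swapped pairs: {mnr_loss(queries, swapped):.6f}")
```

The aligned batch yields a much lower loss than the swapped one, which is the signal that drives the LoRA parameter updates.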

### Training Parameters

- `epochs`: Number of training epochs
- `batch_size`: Batch size for training
- `learning_rate`: Learning rate for optimizer
- `max_length`: Maximum sequence length for tokenization
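
As a quick worked example of how these parameters interact, the sketch below computes the number of optimizer steps and the number of in-batch negatives each query sees. It assumes one optimizer step per batch (no gradient accumulation) and that the last partial batch is kept; `training_schedule` is a hypothetical helper and the dataset size of 1000 is made up.

```python
import math


def training_schedule(num_pairs, batch_size, epochs):
    """Hypothetical helper: derive step counts from dataset and batch size."""
    steps_per_epoch = math.ceil(num_pairs / batch_size)
    return {
        "steps_per_epoch": steps_per_epoch,
        "total_steps": steps_per_epoch * epochs,
        # With in-batch negatives, every other passage in a full batch is a negative.
        "in_batch_negatives": batch_size - 1,
    }


schedule = training_schedule(num_pairs=1000, batch_size=8, epochs=3)
print(schedule)
# {'steps_per_epoch': 125, 'total_steps': 375, 'in_batch_negatives': 7}
```

A larger `batch_size` gives each query more negatives per step but needs more GPU memory, so it is usually the first value to adjust when memory is tight.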


In [None]:
# Train the model
# The trainer uses Multiple Negatives Ranking Loss for contrastive learning
trained_model = train_model(
    model=model,
    tokenizer=tokenizer,
    train_data=train_data,
    val_data=val_data,  # None if validation not available
    epochs=3,  # Adjust based on dataset size
    batch_size=8,  # Adjust based on GPU memory
    learning_rate=2e-4,  # Learning rate for LoRA fine-tuning
    max_length=512,  # Should match chunk max_tokens
    device=device
)

print("Training completed!")


## Step 5: Save Fine-Tuned Model

Save the LoRA adapters so you can load them later for inference or further training.


In [None]:
# Save LoRA adapters
from peft import PeftModel

# Save adapters to timestamped directory
adapter_path = timestamped_path("outputs/models", "pdf_qa_lora_adapter", "")
adapter_path.mkdir(parents=True, exist_ok=True)

trained_model.save_pretrained(str(adapter_path))
tokenizer.save_pretrained(str(adapter_path))

print(f"Saved LoRA adapters to: {adapter_path}")
print("\nTo load the model later:")
print("  from peft import PeftModel")
print("  from src.models.embedding_pipeline import load_embeddinggemma_model")
print("  tokenizer, base_model = load_embeddinggemma_model()")
print(f"  model = PeftModel.from_pretrained(base_model, '{adapter_path}')")


## Step 6: Test Fine-Tuned Model

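Let's run a quick sanity check: embed a sample query and its correct passage with the fine-tuned model and compare the two embeddings. Because this pair comes from the training data, a high score confirms that the pipeline works but says little about retrieval quality on unseen queries.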
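
A single similarity score does not show whether the correct passage would actually be retrieved, because retrieval depends on how it ranks against other candidates. The sketch below ranks several passages by cosine similarity to one query. It uses toy 3-d vectors and hypothetical helper names; in practice the vectors would come from `embed_texts`.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def rank_passages(query_emb, passage_embs):
    """Return passage indices ordered from most to least similar to the query."""
    scores = [cosine(query_emb, p) for p in passage_embs]
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)


# Toy embeddings standing in for real model output.
query_emb = [0.9, 0.1, 0.0]
passage_embs = [
    [0.0, 1.0, 0.0],  # index 0: weakly related
    [1.0, 0.0, 0.0],  # index 1: closest to the query
    [0.0, 0.0, 1.0],  # index 2: unrelated
]

ranking = rank_passages(query_emb, passage_embs)
print(ranking)  # [1, 0, 2]
```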


In [None]:
# Test the fine-tuned model
from src.models.embedding_pipeline import embed_texts
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

# Example query from training data
test_query = train_data[0]['anchor']
correct_passage = train_data[0]['positive']

# Embed query and passage using fine-tuned model
query_emb = embed_texts(test_query, trained_model, tokenizer, device=device, max_length=512)
passage_emb = embed_texts(correct_passage, trained_model, tokenizer, device=device, max_length=512)

# Compute similarity (detach and move to CPU first; .numpy() fails on GPU tensors)
similarity = cosine_similarity(
    query_emb.detach().cpu().numpy(),
    passage_emb.detach().cpu().numpy()
)[0][0]

print(f"Query: '{test_query}'")
print(f"\nCorrect Passage: '{correct_passage[:150]}...'")
print(f"\nSimilarity score: {similarity:.4f}")
print("(Cosine similarity ranges from -1 to 1; higher means the query and passage embeddings are closer)")


## Summary

This notebook demonstrated:
1. ✅ Loading query-passage pairs from PDF-generated dataset
2. ✅ Setting up LoRA for efficient fine-tuning
3. ✅ Training EmbeddingGemma on PDF QA pairs
4. ✅ Saving fine-tuned LoRA adapters
5. ✅ Testing the fine-tuned model

**Next Steps:**
- Use the fine-tuned model for semantic search over your PDF documents
- Evaluate retrieval performance using the evaluation modules
- Consider using hard negatives (from notebook 10) for advanced training
- The fine-tuned model can be loaded and used for inference in other notebooks
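
For the evaluation step, two common retrieval metrics are recall@k and mean reciprocal rank (MRR). The sketch below is a minimal, dependency-free version under the assumption that each query has exactly one correct passage; the function names and toy rankings are hypothetical, and the project's own evaluation modules may compute these differently.

```python
def recall_at_k(rankings, correct, k):
    """Fraction of queries whose correct passage appears in the top k results."""
    hits = sum(1 for ranked, target in zip(rankings, correct) if target in ranked[:k])
    return hits / len(rankings)


def mean_reciprocal_rank(rankings, correct):
    """Average of 1 / rank of the correct passage (0 if it is not in the ranking)."""
    total = 0.0
    for ranked, target in zip(rankings, correct):
        if target in ranked:
            total += 1.0 / (ranked.index(target) + 1)
    return total / len(rankings)


# Toy rankings: each inner list holds passage indices, best match first.
rankings = [[2, 0, 1], [1, 2, 0], [0, 1, 2]]
correct = [2, 2, 2]  # index of the correct passage for each query

print(f"Recall@1: {recall_at_k(rankings, correct, k=1):.3f}")
print(f"Recall@2: {recall_at_k(rankings, correct, k=2):.3f}")
print(f"MRR:      {mean_reciprocal_rank(rankings, correct):.3f}")
```

Running the same metrics on the base model and the fine-tuned model, using held-out queries, gives a more reliable comparison than a single similarity score.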
