# LoRA Configuration

This notebook configures LoRA (Low-Rank Adaptation) for parameter-efficient fine-tuning of EmbeddingGemma. LoRA keeps the base model's weights frozen and trains only a small set of added low-rank adapter matrices.


In [None]:
# Import helper functions from the project's src package
from src.models.embedding_pipeline import load_embeddinggemma_model
from src.models.lora_setup import setup_lora_model, print_trainable_parameters


## Load Base Model

First, we load the base EmbeddingGemma model that we'll fine-tune.


In [None]:
# Load the base model
tokenizer, base_model = load_embeddinggemma_model()

print("Base model loaded successfully")
print(f"Model device: {next(base_model.parameters()).device}")


## Configure and Apply LoRA

We'll configure LoRA with:
- Rank (r=16): Controls the size of the adapter matrices
- Alpha (32): Scaling factor for the adapter update, commonly set to 2*r
- Dropout (0.1): Regularization during training
- Target modules: ["q_proj", "k_proj", "v_proj"] - attention projection layers
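
To make the rank and alpha settings concrete, the sketch below computes a LoRA-adapted linear layer in plain Python. It assumes the standard LoRA formulation, where the output is `W x + (alpha / r) * B (A x)`. The helper names (`matvec`, `lora_forward`) and the tiny matrices are illustrative only and are not part of `src.models.lora_setup`.

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def lora_forward(W, A, B, x, lora_alpha, r):
    """Standard LoRA forward pass: W x + (lora_alpha / r) * B (A x).

    W is the frozen base weight (d_out x d_in); A (r x d_in) and
    B (d_out x r) are the trainable low-rank adapter matrices.
    """
    scaling = lora_alpha / r
    base_out = matvec(W, x)
    adapter_out = matvec(B, matvec(A, x))
    return [b + scaling * a for b, a in zip(base_out, adapter_out)]


# Toy example with d_in = d_out = 2 and rank r = 1 (hypothetical values).
W = [[1.0, 0.0], [0.0, 1.0]]   # frozen base weight
A = [[1.0, 1.0]]               # r x d_in
B = [[0.5], [0.0]]             # d_out x r
x = [2.0, 3.0]

# r=16 with lora_alpha=32 gives a scaling of 32 / 16 = 2.0;
# this toy uses r=1 and lora_alpha=2 to get the same scaling factor.
y = lora_forward(W, A, B, x, lora_alpha=2, r=1)
print(y)
```

A larger rank gives the adapter more capacity at the cost of more trainable parameters, while alpha only rescales the adapter's contribution relative to the frozen weights.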


In [None]:
# Setup LoRA on the model
# This freezes the base model and adds LoRA adapters
model = setup_lora_model(
    base_model,
    r=16,              # LoRA rank
    lora_alpha=32,     # Scaling factor
    lora_dropout=0.1,  # Dropout rate
    target_modules=["q_proj", "k_proj", "v_proj"]  # Attention projection layers
)

print("LoRA adapters applied successfully")


## Verify Trainable Parameters

After applying LoRA, only a small fraction of parameters should be trainable (typically < 1% of the total); the exact share depends on the rank and on which modules are targeted.
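
A back-of-the-envelope check shows why the fraction is small. For each adapted linear layer of shape `d_out x d_in`, LoRA adds `r * (d_in + d_out)` trainable parameters. The layer sizes below are hypothetical placeholders, not EmbeddingGemma's actual dimensions, and the function name is illustrative.

```python
def lora_param_count(d_in, d_out, r):
    """Trainable parameters LoRA adds to one d_out x d_in linear layer."""
    return r * (d_in + d_out)


# Hypothetical square projection layer; not EmbeddingGemma's real size.
d_in, d_out, r = 1024, 1024, 16

full = d_in * d_out                       # frozen weights in the layer
added = lora_param_count(d_in, d_out, r)  # adapter weights (A and B)

print(f"Frozen: {full:,}  LoRA: {added:,}  ratio: {added / full:.2%}")
```

The per-layer ratio is higher than the whole-model figure, because embeddings, MLP layers, and any untargeted projections contribute frozen parameters but receive no adapters.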


In [None]:
# Print trainable parameter statistics
stats = print_trainable_parameters(model)

print("\nParameter Efficiency:")
print(f"  - Only {stats['percentage']:.2f}% of parameters are trainable")
print("  - This means we can fine-tune with much less memory and compute!")


## Verify Model Still Works

Let's make sure the model with LoRA still produces embeddings correctly (it should behave identically to the base model before training).
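
The outputs are expected to match because of the usual LoRA initialization: `B` starts at zero, so the adapter contributes nothing until training updates it. The sketch below illustrates this with toy matrices. It assumes `setup_lora_model` follows that convention (its internals are not shown here), and all names are illustrative.

```python
import random


def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


random.seed(0)
d, r, scaling = 4, 2, 2.0  # toy sizes; scaling = lora_alpha / r

W = [[random.gauss(0, 1) for _ in range(d)] for _ in range(d)]  # frozen base
A = [[random.gauss(0, 1) for _ in range(d)] for _ in range(r)]  # random init
B = [[0.0] * r for _ in range(d)]                               # zero init
x = [random.gauss(0, 1) for _ in range(d)]

base_out = matvec(W, x)
adapter_out = matvec(B, matvec(A, x))
lora_out = [b + scaling * a for b, a in zip(base_out, adapter_out)]

# B is all zeros, so the adapter term vanishes and the outputs match exactly.
print(lora_out == base_out)
```

If the real model's embeddings differ from the base model's before training, one possible cause is a non-zero adapter initialization; another is dropout being active, so it may be worth confirming the model is in evaluation mode.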


In [None]:
# Test that the model still works
from src.models.embedding_pipeline import embed_texts

test_text = "This is a test sentence."
embedding = embed_texts(test_text, model, tokenizer)

print(f"Embedding shape: {embedding.shape}")
print(f"Embedding norm: {embedding.norm().item():.4f} (should be ~1.0 for normalized embeddings)")
print("✓ Model with LoRA is working correctly!")
