# Retrieval Model Distillation Driver

## Plan Overview

**Teacher**: infly/inf-retriever-v1-pro (3584d)

**Student**: sentence-transformers/all-mpnet-base-v2 (768d) + Projection (768→1536→3584)

**Phase 1**: General Distillation
- Loss: MSE (0.4) + Cosine (0.6)
- Data: MS MARCO + NQ + HotpotQA

**Phase 2**: Task-Specific Training
- Loss: InfoNCE (0.8) + MSE (0.2)
- Data: MS MARCO with hard negatives
- Temperature: 0.02

## 1. Setup and Imports

In [None]:
import os
import torch
from sentence_transformers import SentenceTransformer
import logging

# Import from distill module
from distill import (
    ProjectionLayer,
    StudentModelWithProjection,
    train_phase1,
    train_phase2,
    evaluate_retrieval
)

# Logging: configure the root logger at INFO level and create this
# notebook's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Compute device: CUDA when PyTorch reports it as available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if device.type == "cuda":
    # Report the name of GPU 0 and the number of GPUs PyTorch can see.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Available GPUs: {torch.cuda.device_count()}")

## 2. Configuration

In [None]:
# ---------------------------------------------------------------------------
# Models: teacher/student identifiers and the embedding widths involved
# (student 768 -> projection hidden 1536 -> teacher 3584).
# ---------------------------------------------------------------------------
TEACHER_MODEL = "infly/inf-retriever-v1-pro"
STUDENT_MODEL = "sentence-transformers/all-mpnet-base-v2"
TEACHER_DIM = 3584
STUDENT_DIM = 768
PROJECTION_HIDDEN_DIM = 1536

# ---------------------------------------------------------------------------
# Phase 1 (general distillation): loss weights are MSE 0.4 + cosine 0.6.
# ---------------------------------------------------------------------------
PHASE1_CONFIG = {
    "batch_size": 64,
    "learning_rate": 2e-5,
    "warmup_steps": 1000,
    "num_epochs": 1,
    "mse_weight": 0.4,
    "cosine_weight": 0.6,
    "max_length": 512,
    "gradient_accumulation_steps": 4,
    "max_samples_per_dataset": 100000,
}

# ---------------------------------------------------------------------------
# Phase 2 (task-specific): loss weights are InfoNCE 0.8 + MSE 0.2, with a
# temperature of 0.02.
# ---------------------------------------------------------------------------
PHASE2_CONFIG = {
    "batch_size": 64,
    "learning_rate": 5e-6,
    "warmup_steps": 500,
    "num_epochs": 3,
    "infonce_weight": 0.8,
    "mse_weight": 0.2,
    "temperature": 0.02,
    "max_length": 512,
    "num_negatives": 7,
    "gradient_accumulation_steps": 8,
    "max_samples": 500000,
}

# ---------------------------------------------------------------------------
# Checkpoint locations. The directory is created up front; the two checkpoint
# values are path stems (later cells append ".pt" to them).
# ---------------------------------------------------------------------------
OUTPUT_DIR = "./checkpoints"
os.makedirs(OUTPUT_DIR, exist_ok=True)
PHASE1_CHECKPOINT = os.path.join(OUTPUT_DIR, "phase1_best")
PHASE2_CHECKPOINT = os.path.join(OUTPUT_DIR, "phase2_best")

## 3. Load Models

In [None]:
# --- Teacher -----------------------------------------------------------------
# The teacher only provides targets: put it in eval mode, turn off gradients
# on all of its parameters, and move it to the selected device.
logger.info(f"Loading teacher model: {TEACHER_MODEL}")
teacher_model = SentenceTransformer(TEACHER_MODEL)
teacher_model.eval()
teacher_model.requires_grad_(False)
teacher_model = teacher_model.to(device)

# --- Student -----------------------------------------------------------------
# Student encoder wrapped together with the projection layer
# (STUDENT_DIM -> PROJECTION_HIDDEN_DIM -> TEACHER_DIM), moved to the device.
logger.info(f"Loading student model: {STUDENT_MODEL}")
projection_layer = ProjectionLayer(STUDENT_DIM, PROJECTION_HIDDEN_DIM, TEACHER_DIM)
student_model = StudentModelWithProjection(STUDENT_MODEL, projection_layer).to(device)

# --- Smoke test --------------------------------------------------------------
# Encode one sentence with the teacher and with the student (both with and
# without the projection) and print the resulting shapes.
test_text = ["This is a test sentence."]
with torch.no_grad():
    teacher_emb = teacher_model.encode(test_text, convert_to_tensor=True)
    student_emb_768 = student_model.encode(test_text, return_projected=False)
    student_emb_3584 = student_model.encode(test_text, return_projected=True)

print("\nModel test successful:")
print(f"  Teacher: {teacher_emb.shape}")
print(f"  Student (768d): {student_emb_768.shape}")
print(f"  Student (3584d): {student_emb_3584.shape}")

# --- Parameter counts --------------------------------------------------------
# Element counts of the student encoder and of the projection, reported
# separately and summed.
student_params = sum(weight.numel() for weight in student_model.student.parameters())
projection_params = sum(weight.numel() for weight in student_model.projection.parameters())
print("\nParameter counts:")
print(f"  Student base: {student_params:,}")
print(f"  Projection: {projection_params:,}")
print(f"  Total: {student_params + projection_params:,}")

## 4. Phase 1: General Distillation

In [None]:
# Phase 1: log the configuration, then hand the student, the frozen teacher,
# the Phase 1 settings, the device and the checkpoint path stem to
# train_phase1. The returned model replaces student_model.
logger.info("Starting Phase 1: General Distillation")
logger.info(f"Config: {PHASE1_CONFIG}")

student_model = train_phase1(student_model, teacher_model, PHASE1_CONFIG, device, PHASE1_CHECKPOINT)

logger.info(f"Phase 1 complete. Best model saved to: {PHASE1_CHECKPOINT}.pt")

## 5. Load Phase 1 Checkpoint and Evaluate

In [None]:
# Optional: restore a saved Phase 1 model instead of (re)running Phase 1.
# Uncomment the three lines below to do so.
# checkpoint = torch.load(f"{PHASE1_CHECKPOINT}.pt")
# student_model.load_state_dict(checkpoint['model_state_dict'])
# logger.info(f"Loaded Phase 1 checkpoint from: {PHASE1_CHECKPOINT}.pt")

# Sanity check after Phase 1: embed a few queries with both models
# (normalized; student output taken after the projection) and compare each
# teacher/student pair by cosine similarity.
logger.info("\nEvaluating Phase 1 model...")
test_queries = [
    "What is the capital of France?",
    "How does photosynthesis work?",
    "Explain quantum mechanics",
]

with torch.no_grad():
    teacher_embs = teacher_model.encode(
        test_queries,
        convert_to_tensor=True,
        normalize_embeddings=True,
    )
    student_embs = student_model.encode(
        test_queries,
        normalize=True,
        return_projected=True,
    )

    # One similarity value per query (dim=1 reduces over the embedding axis).
    cosine_sim = torch.nn.functional.cosine_similarity(teacher_embs, student_embs, dim=1)

print(f"Average cosine similarity with teacher: {cosine_sim.mean().item():.4f}")
print(f"Per-query similarities: {cosine_sim.tolist()}")

## 6. Phase 2: Task-Specific Training

In [None]:
# Phase 2: continue training the student from its current state using the
# Phase 2 settings; train_phase2 receives the checkpoint path stem and its
# return value replaces student_model.
student_model = train_phase2(student_model, teacher_model, PHASE2_CONFIG, device, PHASE2_CHECKPOINT)

## 7. Export Model to Artifacts (SentenceTransformer Format)

In [None]:
## Save Model to Artifacts with Proper SentenceTransformer Format

from distill.save_model import save_distilled_model_to_artifacts

# Export the student (with its projection) under ./artifacts and keep the
# returned path for the usage hint printed below.
# NOTE(review): this passes the Phase 1 checkpoint even though the cell runs
# after Phase 2 training — confirm PHASE1_CHECKPOINT (rather than
# PHASE2_CHECKPOINT) is the intended source.
output_path = save_distilled_model_to_artifacts(
    student_model=student_model,
    checkpoint_path=f"{PHASE1_CHECKPOINT}.pt",
    artifacts_dir="./artifacts",
    model_name="distilled-mpnet-3584d",
)

# Banner plus a copy-pasteable snippet showing how to load the export.
print(f"\n{'='*80}")
print("Model saved and ready to use!")
print(f"{'='*80}")
print("\nLoad it in any script with:")
print(">>> from sentence_transformers import SentenceTransformer")
print(f">>> model = SentenceTransformer('{output_path}')")
print(">>> embeddings = model.encode(['your texts here'])  # Shape: (N, 3584)")

In [None]:
# Round-trip check: load the exported model from ./artifacts through the
# plain SentenceTransformer API, encode one text, and display the shape.
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("./artifacts/distilled-mpnet-3584d")
embeddings = model.encode(["your texts here"])  # Shape: (N, 3584)
embeddings.shape

## 8. Save Final Model

In [None]:
# Load best Phase 2 checkpoint.
# map_location=device remaps the stored tensors onto the device selected for
# this session, so a checkpoint written on a GPU can still be loaded when the
# notebook is re-run on a CPU-only machine (without it torch.load fails there).
checkpoint = torch.load(f"{PHASE2_CHECKPOINT}.pt", map_location=device)
student_model.load_state_dict(checkpoint['model_state_dict'])

# Save the complete model as two artifacts inside one directory.
final_model_path = os.path.join(OUTPUT_DIR, "distilled_mpnet_final")
os.makedirs(final_model_path, exist_ok=True)

# 1) Student base model, written via its own save() method.
student_model.student.save(os.path.join(final_model_path, "student_base"))

# 2) Projection layer weights, stored as a plain state dict.
torch.save(
    student_model.projection.state_dict(),
    os.path.join(final_model_path, "projection_layer.pt")
)

logger.info(f"\nFinal model saved to: {final_model_path}")
logger.info("  - student_base/: Base sentence-transformer model (768d)")
logger.info("  - projection_layer.pt: Projection weights (768→3584)")

# Closing banner with a summary of the intended usage modes.
print("\n" + "=" * 80)
print("Distillation Complete!")
print("=" * 80)
print("\nYou can now use the model for inference:")
print("  - For fast retrieval: Use 768d embeddings (student only)")
print("  - For hybrid system: Use 3584d embeddings (student + projection)")
print("  - Switch to teacher (3584d native) for hard cases")

## 9. Inference Example

In [None]:
# Example inputs: two queries and two documents.
example_queries = [
    "What is machine learning?",
    "How does neural network work?",
]

example_docs = [
    "Machine learning is a subset of artificial intelligence.",
    "Neural networks are computing systems inspired by biological neural networks.",
]

# Student side: normalized, projected embeddings; the query-by-document
# score matrix is the product of the query matrix and the transposed
# document matrix.
with torch.no_grad():
    query_emb = student_model.encode(example_queries, normalize=True, return_projected=True)
    doc_emb = student_model.encode(example_docs, normalize=True, return_projected=True)

    similarities = query_emb @ doc_emb.T
    print("\nSimilarity scores (student model, 3584d):")
    print(similarities)

# Teacher side: the same computation with the teacher's normalized
# embeddings, followed by the mean absolute gap between the two matrices.
with torch.no_grad():
    teacher_query_emb = teacher_model.encode(
        example_queries, convert_to_tensor=True, normalize_embeddings=True
    )
    teacher_doc_emb = teacher_model.encode(
        example_docs, convert_to_tensor=True, normalize_embeddings=True
    )

    teacher_similarities = teacher_query_emb @ teacher_doc_emb.T
    print("\nSimilarity scores (teacher model, 3584d):")
    print(teacher_similarities)

    print(f"\nAverage difference: {(similarities - teacher_similarities).abs().mean().item():.6f}")