## 1. Setup and Imports

# Retrieval Model Distillation Driver

## Plan Overview

**Teacher**: infly/inf-retriever-v1-pro (3584d, MRL-trained)

**Student**: sentence-transformers/all-mpnet-base-v2 (768d)

**Two Distillation Modes:**

### Mode 1: With Projection (Higher Quality)
- Architecture: Student (768d) + Projection (768→1536→3584)
- Teacher target: Full 3584d embeddings
- Best for: Maximum quality, when computational cost is acceptable
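
The 768→1536→3584 projection can be pictured as a small two-layer MLP. The sketch below is an illustration only: the class name `SketchProjection`, the GELU activation, and the absence of dropout or normalization are assumptions, and the actual `ProjectionLayer` in the `distill` module may be built differently.

```python
import torch
import torch.nn as nn


class SketchProjection(nn.Module):
    """Hypothetical 768 -> 1536 -> 3584 head (not the distill.ProjectionLayer source)."""

    def __init__(self, in_dim: int = 768, hidden_dim: int = 1536, out_dim: int = 3584):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),  # assumed activation; the real layer may use another
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


proj = SketchProjection()
student_batch = torch.randn(4, 768)   # stand-in for student embeddings
projected = proj(student_batch)       # shape: (4, 3584)
num_params = sum(p.numel() for p in proj.parameters())
```

Under these assumptions the head adds roughly 6.7M parameters on top of the student encoder.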

### Mode 2: MRL-Based (More Efficient)
- Architecture: Student (768d) only, no projection
- Teacher target: First 768d of teacher embeddings (MRL slicing)
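- Best for: Efficiency; relies on the teacher's Matryoshka (MRL) training, under which a prefix of the embedding is itself a usable embedding
- Trade-off: lower computational cost and smaller model size than Mode 1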
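
MRL slicing amounts to keeping the first 768 dimensions of each teacher embedding and re-normalizing the result. A minimal sketch, using random tensors in place of real teacher output (the helper name `mrl_slice` is hypothetical):

```python
import torch
import torch.nn.functional as F


def mrl_slice(teacher_emb: torch.Tensor, target_dim: int) -> torch.Tensor:
    """Keep the first `target_dim` dimensions and re-normalize rows to unit length."""
    sliced = teacher_emb[:, :target_dim]
    return F.normalize(sliced, p=2, dim=-1)


teacher_emb = torch.randn(2, 3584)      # stand-in for teacher embeddings
target = mrl_slice(teacher_emb, 768)    # shape: (2, 768)
```

Re-normalizing after the slice matters because a prefix of a unit-length vector is generally shorter than unit length.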

**Phase 1**: General Distillation
- Loss: MSE (0.4) + Cosine (0.6)
- Data: MS MARCO + NQ + HotpotQA
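
One plausible form of the Phase 1 objective is a weighted sum of a mean-squared-error term and a cosine-distance term. The sketch below assumes the cosine term is `1 - cosine_similarity` averaged over the batch; the function name `phase1_loss` is hypothetical and the loss inside `train_phase1` may be defined differently.

```python
import torch
import torch.nn.functional as F


def phase1_loss(student: torch.Tensor, teacher: torch.Tensor,
                mse_weight: float = 0.4, cosine_weight: float = 0.6) -> torch.Tensor:
    """Weighted MSE + cosine-distance loss between student and teacher embeddings."""
    mse = F.mse_loss(student, teacher)
    cosine = (1.0 - F.cosine_similarity(student, teacher, dim=-1)).mean()
    return mse_weight * mse + cosine_weight * cosine


teacher_target = F.normalize(torch.randn(3, 16), dim=-1)
student_output = F.normalize(torch.randn(3, 16), dim=-1)
loss = phase1_loss(student_output, teacher_target)
```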

**Phase 2**: Task-Specific
- Loss: InfoNCE (0.8) + MSE (0.2)
- Data: MS MARCO with hard negatives
- Temperature: 0.02
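
The InfoNCE term can be sketched as a cross-entropy over one positive and several hard negatives per query, with scores divided by the temperature. This version assumes L2-normalized inputs and per-query negatives only (no in-batch negatives); the name `info_nce` is hypothetical and `train_phase2` may differ in both respects.

```python
import torch
import torch.nn.functional as F


def info_nce(query: torch.Tensor, positive: torch.Tensor,
             negatives: torch.Tensor, temperature: float = 0.02) -> torch.Tensor:
    """query, positive: (B, D); negatives: (B, K, D). Inputs are assumed L2-normalized."""
    pos_score = (query * positive).sum(dim=-1, keepdim=True)         # (B, 1)
    neg_score = torch.einsum("bd,bkd->bk", query, negatives)         # (B, K)
    logits = torch.cat([pos_score, neg_score], dim=1) / temperature  # (B, 1 + K)
    # The positive always sits at column 0, so every label is 0.
    labels = torch.zeros(query.size(0), dtype=torch.long, device=query.device)
    return F.cross_entropy(logits, labels)


q = F.normalize(torch.randn(2, 8), dim=-1)
pos = F.normalize(torch.randn(2, 8), dim=-1)
negs = F.normalize(torch.randn(2, 7, 8), dim=-1)   # 7 negatives per query
loss = info_nce(q, pos, negs)
```

A temperature as low as 0.02 sharpens the softmax considerably, so small differences in cosine similarity between the positive and the hardest negative dominate the gradient.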

In [None]:
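import os

# Restrict this process to GPUs 4-7. Set this before CUDA is initialized
# (safest: before importing torch), otherwise it has no effect.
os.environ['CUDA_VISIBLE_DEVICES'] = "4,5,6,7"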

In [None]:
import os
import torch
from sentence_transformers import SentenceTransformer
import logging

# Import from distill module
from distill import (
    ProjectionLayer,
    StudentModelWithProjection,
    train_phase1,
    train_phase2,
    evaluate_retrieval
)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Available GPUs: {torch.cuda.device_count()}")

## 2. Configuration
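
Both phase configurations in this section pair a per-step batch size with gradient accumulation, so the optimizer sees a larger effective batch than `batch_size` alone suggests. The helper below is hypothetical (not part of the `distill` module) and simply makes that arithmetic explicit, assuming one process unless `world_size` is given.

```python
def effective_batch_size(config: dict, world_size: int = 1) -> int:
    """Samples contributing to each optimizer step."""
    return config["batch_size"] * config["gradient_accumulation_steps"] * world_size


# Values copied from the Phase 1 and Phase 2 configurations in this section.
phase1 = {"batch_size": 128, "gradient_accumulation_steps": 4}
phase2 = {"batch_size": 16, "gradient_accumulation_steps": 8}

phase1_effective = effective_batch_size(phase1)   # 512
phase2_effective = effective_batch_size(phase2)   # 128
```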

In [None]:
# Model configuration
TEACHER_MODEL = "infly/inf-retriever-v1-pro"
STUDENT_MODEL = "sentence-transformers/all-mpnet-base-v2"
TEACHER_DIM = 3584
STUDENT_DIM = 768
PROJECTION_HIDDEN_DIM = 1536

# ============================================================================
# DISTILLATION MODE CONFIGURATION
# ============================================================================
# Set USE_PROJECTION to choose between two distillation approaches:
#
# Option 1: WITH PROJECTION (USE_PROJECTION = True)
#   - Student (768d) -> Projection Layer -> 3584d output
#   - Compared against full teacher embeddings (3584d)
#   - Best for maximum quality, higher computational cost
#
# Option 2: MRL-BASED (USE_PROJECTION = False)
#   - Student (768d) -> Direct output (no projection)
#   - Compared against teacher's first 768d (MRL slicing)
#   - Best for efficiency, lower computational cost
#   - Leverages teacher's Matryoshka Representation Learning (MRL) training
# ============================================================================
USE_PROJECTION = False  # True: projection mode (3584d); False: MRL-based mode (768d)

# Phase 1 configuration
PHASE1_CONFIG = {
    'batch_size': 128,
    'learning_rate': 2e-5,
    'warmup_steps': 1000,
    'num_epochs': 1,
    'mse_weight': 0.4,
    'cosine_weight': 0.6,
    'max_length': 512,
    'gradient_accumulation_steps': 4,
    'max_samples_per_dataset': 10000
}

# Phase 2 configuration
PHASE2_CONFIG = {
    'batch_size': 16,
    'learning_rate': 5e-6,
    'warmup_steps': 500,
    'num_epochs': 1,
    'infonce_weight': 0.8,
    'mse_weight': 0.2,
    'temperature': 0.02,
    'max_length': 512,
    'num_negatives': 7,
    'gradient_accumulation_steps': 8,
    'max_samples': 100
}

# Paths
OUTPUT_DIR = "./checkpoints"
PHASE1_CHECKPOINT = os.path.join(OUTPUT_DIR, "phase1_best")
PHASE2_CHECKPOINT = os.path.join(OUTPUT_DIR, "phase2_best")
os.makedirs(OUTPUT_DIR, exist_ok=True)

## Alternative: Multi-GPU Training with DDP

**This notebook runs on a single GPU.** For faster training with multiple GPUs, use the DDP launcher script:

### Quick Start - Multi-GPU Training

```bash
# Training with 8 GPUs (recommended)
cd /home/yuanchu/src/distill
torchrun --nproc_per_node=8 train_ddp.py

# OR use the shell script
./run_ddp.sh 8 True  # 8 GPUs, with projection
./run_ddp.sh 8 False  # 8 GPUs, MRL mode (no projection)
```

### Performance Comparison

| Configuration | Training Time (Phase 1 + 2) | Speedup |
|--------------|------------------------------|---------|
| Single GPU | ~32 hours | 1x |
| 4 GPU DDP | ~8-9 hours | ~3.8x |
| 8 GPU DDP | ~4-5 hours | ~7-7.5x |
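
The speedup column follows from dividing the single-GPU time by the multi-GPU time. The short calculation below reproduces it and also derives parallel efficiency (speedup divided by GPU count); using the midpoint of each quoted time range is an assumption made here for illustration.

```python
def speedup_and_efficiency(single_gpu_hours: float, multi_gpu_hours: float, num_gpus: int):
    """Return (speedup, parallel efficiency) relative to a single GPU."""
    speedup = single_gpu_hours / multi_gpu_hours
    efficiency = speedup / num_gpus
    return speedup, efficiency


s8, e8 = speedup_and_efficiency(32.0, 4.5, 8)   # midpoint of the ~4-5 hour range
s4, e4 = speedup_and_efficiency(32.0, 8.5, 4)   # midpoint of the ~8-9 hour range
print(f"8 GPUs: {s8:.2f}x speedup, {e8:.0%} efficiency")
print(f"4 GPUs: {s4:.2f}x speedup, {e4:.0%} efficiency")
```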

### Key Advantages of DDP:
- **7-7.5x speedup** on 8 GPUs vs single GPU
- **Balanced GPU utilization** (all GPUs ~95-98%)
- **No GPU bottleneck** (unlike DataParallel)
- **Efficient gradient synchronization** via NCCL All-Reduce

### Configuration
Edit `train_ddp.py` or pass command-line arguments:
```bash
torchrun --nproc_per_node=8 train_ddp.py \
    --use_projection=True \
    --batch_size=64 \
    --phase1_epochs=1 \
    --phase2_epochs=3 \
    --phase1_lr=2e-5
```

See `README_DDP.md` for complete documentation.
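
Flags written as `--use_projection=True` need care if they are parsed with `argparse`, because `type=bool` treats any non-empty string, including `"False"`, as true. The sketch below shows one way such flags could be parsed. It is a hypothetical illustration, not the contents of `train_ddp.py`; the helper `str2bool` and all default values are assumptions.

```python
import argparse


def str2bool(value: str) -> bool:
    """Parse 'True'/'False' style strings explicitly instead of relying on bool()."""
    lowered = value.strip().lower()
    if lowered in {"true", "1", "yes"}:
        return True
    if lowered in {"false", "0", "no"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Hypothetical DDP launcher arguments")
    parser.add_argument("--use_projection", type=str2bool, default=False)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--phase1_epochs", type=int, default=1)
    parser.add_argument("--phase2_epochs", type=int, default=1)
    parser.add_argument("--phase1_lr", type=float, default=2e-5)
    return parser


args = build_parser().parse_args(
    ["--use_projection=True", "--batch_size=64", "--phase2_epochs=3", "--phase1_lr=2e-5"]
)
```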

---

**Continue below for single-GPU notebook training:**

## 3. Load Models
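
The teacher is kept frozen during distillation, so only the student (and the projection, when used) receives gradients. The sketch below shows the freeze-and-count pattern on tiny stand-in modules; the helper names `freeze` and `count_parameters` are hypothetical.

```python
import torch.nn as nn


def freeze(module: nn.Module) -> nn.Module:
    """Put a module in eval mode and disable gradients for all of its parameters."""
    module.eval()
    for param in module.parameters():
        param.requires_grad = False
    return module


def count_parameters(module: nn.Module, trainable_only: bool = False) -> int:
    return sum(
        p.numel() for p in module.parameters()
        if p.requires_grad or not trainable_only
    )


teacher = freeze(nn.Linear(8, 4))   # stand-in for the frozen teacher
student = nn.Linear(8, 4)           # stand-in for the trainable student

teacher_trainable = count_parameters(teacher, trainable_only=True)   # 0
student_trainable = count_parameters(student, trainable_only=True)   # 36
```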

In [None]:
# Load teacher model (frozen)
logger.info(f"Loading teacher model: {TEACHER_MODEL}")
teacher_model = SentenceTransformer(TEACHER_MODEL)
teacher_model.eval()
for param in teacher_model.parameters():
    param.requires_grad = False
teacher_model = teacher_model.to(device)

# Create student model with optional projection
logger.info(f"Loading student model: {STUDENT_MODEL}")
logger.info(f"Distillation mode: {'WITH PROJECTION' if USE_PROJECTION else 'MRL-BASED (no projection)'}")

if USE_PROJECTION:
    # Create projection layer for upsampling student embeddings
    projection_layer = ProjectionLayer(STUDENT_DIM, PROJECTION_HIDDEN_DIM, TEACHER_DIM)
    student_model = StudentModelWithProjection(
        STUDENT_MODEL, 
        projection_layer=projection_layer,
        use_projection=True
    )
else:
    # No projection - use MRL slicing approach
    student_model = StudentModelWithProjection(
        STUDENT_MODEL,
        projection_layer=None,
        use_projection=False
    )

student_model = student_model.to(device)

# Test models
test_text = ["This is a test sentence."]
with torch.no_grad():
    teacher_emb = teacher_model.encode(test_text, convert_to_tensor=True)
    student_emb_base = student_model.encode(test_text, return_projected=False)
    student_emb_output = student_model.encode(test_text, return_projected=True)
    
print(f"\nModel test successful:")
print(f"  Teacher (full): {teacher_emb.shape}")
print(f"  Student (base): {student_emb_base.shape}")
print(f"  Student (output): {student_emb_output.shape}")

if USE_PROJECTION:
    print(f"\n  Mode: WITH PROJECTION")
    print(f"    Student {STUDENT_DIM}d -> Projection -> {TEACHER_DIM}d")
    print(f"    Compared against: Teacher full {TEACHER_DIM}d")
else:
    print(f"\n  Mode: MRL-BASED (no projection)")
    print(f"    Student {STUDENT_DIM}d (direct)")
    print(f"    Compared against: Teacher's first {STUDENT_DIM}d")

# Count parameters
student_params = sum(p.numel() for p in student_model.student.parameters())
if USE_PROJECTION:
    projection_params = sum(p.numel() for p in student_model.projection.parameters())
    print(f"\nParameter counts:")
    print(f"  Student base: {student_params:,}")
    print(f"  Projection: {projection_params:,}")
    print(f"  Total: {student_params + projection_params:,}")
else:
    print(f"\nParameter counts:")
    print(f"  Student base: {student_params:,}")
    print(f"  Projection: N/A (not used)")
    print(f"  Total: {student_params:,}")

## 4. Phase 1: General Distillation

In [None]:
# Run Phase 1 training
logger.info("Starting Phase 1: General Distillation")
logger.info(f"Config: {PHASE1_CONFIG}")

student_model = train_phase1(
    student_model, 
    teacher_model, 
    PHASE1_CONFIG, 
    device, 
    PHASE1_CHECKPOINT
)

logger.info(f"Phase 1 complete. Best model saved to: {PHASE1_CHECKPOINT}.pt")

## 5. Load Phase 1 Checkpoint and Evaluate
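
Checkpoints in this notebook are read through a `model_state_dict` key. The round trip below illustrates that layout on a tiny stand-in module and passes `map_location` so that a checkpoint written on a GPU can also be loaded on a CPU-only machine. Any other keys the real checkpoint may contain are not shown here.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # tiny stand-in for the student model

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "phase1_best.pt")
    torch.save({"model_state_dict": model.state_dict()}, path)

    # map_location remaps stored tensors to the requested device at load time.
    checkpoint = torch.load(path, map_location="cpu")
    restored = nn.Linear(4, 2)
    restored.load_state_dict(checkpoint["model_state_dict"])
```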

In [None]:
# Optional: load a saved Phase 1 checkpoint (when resuming or skipping Phase 1).
# Uncomment the lines below to use it.
# checkpoint = torch.load(f"{PHASE1_CHECKPOINT}.pt", map_location=device)
# student_model.load_state_dict(checkpoint['model_state_dict'])
# logger.info(f"Loaded Phase 1 checkpoint from: {PHASE1_CHECKPOINT}.pt")

# Quick evaluation after Phase 1
logger.info("\nEvaluating Phase 1 model...")
test_queries = [
    "What is the capital of France?",
    "How does photosynthesis work?",
    "Explain quantum mechanics"
]

with torch.no_grad():
    # Get teacher embeddings
    teacher_embs = teacher_model.encode(test_queries, convert_to_tensor=True, normalize_embeddings=True)
    
    # Get student embeddings
    student_embs = student_model.encode(test_queries, normalize=True, return_projected=True)
    
    # Slice teacher embeddings if using MRL mode
    if not USE_PROJECTION:
        target_dim = student_model.get_output_dim()
        teacher_embs_sliced = teacher_embs[:, :target_dim]
        teacher_embs_sliced = torch.nn.functional.normalize(teacher_embs_sliced, p=2, dim=-1)
        # Calculate cosine similarity with sliced teacher
        cosine_sim = torch.nn.functional.cosine_similarity(teacher_embs_sliced, student_embs, dim=1)
        print(f"Mode: MRL-based (comparing {target_dim}d embeddings)")
    else:
        # Calculate cosine similarity with full teacher
        cosine_sim = torch.nn.functional.cosine_similarity(teacher_embs, student_embs, dim=1)
        print(f"Mode: With projection (comparing {TEACHER_DIM}d embeddings)")
    
    print(f"Average cosine similarity with teacher: {cosine_sim.mean().item():.4f}")
    print(f"Per-query similarities: {cosine_sim.tolist()}")

## 6. Phase 2: Task-Specific Training

In [None]:
# Run Phase 2 training
student_model = train_phase2(
    student_model, 
    teacher_model, 
    PHASE2_CONFIG, 
    device, 
    PHASE2_CHECKPOINT
)

## 7. Export and Verify the Model

In [None]:
# Save the model to artifacts/ in proper SentenceTransformer format

from distill.save_model import save_distilled_model_to_artifacts

# Save the model to artifacts/ (model name auto-generated based on mode)
output_path = save_distilled_model_to_artifacts(
    student_model=student_model,
    checkpoint_path=f"{PHASE2_CHECKPOINT}.pt",
    artifacts_dir="./artifacts",
    model_name=None  # Auto-generated: "distilled-mpnet-3584d" or "distilled-mpnet-768d-mrl"
)

print(f"\n{'='*80}")
print("Model saved and ready to use!")
print(f"{'='*80}")
print(f"\nLoad it in any script with:")
print(f">>> from sentence_transformers import SentenceTransformer")
print(f">>> model = SentenceTransformer('{output_path}')")

if USE_PROJECTION:
    print(f">>> embeddings = model.encode(['your texts here'])  # Shape: (N, {TEACHER_DIM})")
else:
    print(f">>> embeddings = model.encode(['your texts here'])  # Shape: (N, {STUDENT_DIM})")

In [None]:
# Test loading the saved model
from sentence_transformers import SentenceTransformer

# Use the correct model path based on mode
if USE_PROJECTION:
    model_path = './artifacts/distilled-mpnet-3584d'
else:
    model_path = './artifacts/distilled-mpnet-768d-mrl'

print(f"Loading model from: {model_path}")
model = SentenceTransformer(model_path)
embeddings = model.encode(['your texts here'])
print(f"Embeddings shape: {embeddings.shape}")  # (1, 3584) or (1, 768) depending on mode

## 8. Save Final Model

In [None]:
# Load best Phase 2 checkpoint (map_location keeps this working on CPU-only machines)
checkpoint = torch.load(f"{PHASE2_CHECKPOINT}.pt", map_location=device)
student_model.load_state_dict(checkpoint['model_state_dict'])

# Save the complete model
final_model_path = os.path.join(OUTPUT_DIR, "distilled_mpnet_final")
os.makedirs(final_model_path, exist_ok=True)

# Save student base model
student_model.student.save(os.path.join(final_model_path, "student_base"))

# Save projection layer if using projection mode
if USE_PROJECTION:
    torch.save(
        student_model.projection.state_dict(),
        os.path.join(final_model_path, "projection_layer.pt")
    )
    logger.info(f"\nFinal model saved to: {final_model_path}")
    logger.info("  - student_base/: Base sentence-transformer model (768d)")
    logger.info("  - projection_layer.pt: Projection weights (768→3584)")
else:
    logger.info(f"\nFinal model saved to: {final_model_path}")
    logger.info("  - student_base/: Base sentence-transformer model (768d, MRL-distilled)")

print("\n" + "=" * 80)
print("Distillation Complete!")
print("=" * 80)

if USE_PROJECTION:
    print(f"\nYou can now use the model for inference:")
    print(f"  - For fast retrieval: Use 768d embeddings (student only)")
    print(f"  - For hybrid system: Use 3584d embeddings (student + projection)")
    print(f"  - Switch to teacher (3584d native) for hard cases")
else:
    print(f"\nYou can now use the model for inference:")
    print(f"  - The model produces 768d embeddings")
    print(f"  - Trained to match teacher's first 768d (MRL)")
    print(f"  - More efficient than projection-based approach")

## 9. Inference Example
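
With normalized embeddings, retrieval reduces to a matrix product followed by a top-k selection per query. A minimal sketch using hand-written 2-d vectors in place of real model output:

```python
import torch
import torch.nn.functional as F

# Toy embeddings: 2 queries and 3 documents in a 2-d space.
query_emb = F.normalize(torch.tensor([[1.0, 0.0], [0.0, 1.0]]), dim=-1)
doc_emb = F.normalize(torch.tensor([[0.9, 0.1], [0.1, 0.9], [0.7, 0.7]]), dim=-1)

scores = query_emb @ doc_emb.T                         # (num_queries, num_docs) cosine similarities
top_scores, top_idx = torch.topk(scores, k=2, dim=1)   # best 2 documents per query
print(top_idx.tolist())                                # [[0, 2], [1, 2]]
```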

In [None]:
# Example: Encode queries and documents
example_queries = [
    "What is machine learning?",
    "How does a neural network work?"
]

example_docs = [
    "Machine learning is a subset of artificial intelligence.",
    "Neural networks are computing systems inspired by biological neural networks."
]

# Encode with student
with torch.no_grad():
    query_emb = student_model.encode(example_queries, normalize=True, return_projected=True)
    doc_emb = student_model.encode(example_docs, normalize=True, return_projected=True)
    
    similarities = torch.matmul(query_emb, doc_emb.T)
    output_dim = student_model.get_output_dim()
    mode_str = "with projection" if USE_PROJECTION else "MRL-based"
    print(f"\nSimilarity scores (student model, {output_dim}d, {mode_str}):")
    print(similarities)

# Compare with teacher
with torch.no_grad():
    teacher_query_emb = teacher_model.encode(example_queries, convert_to_tensor=True, normalize_embeddings=True)
    teacher_doc_emb = teacher_model.encode(example_docs, convert_to_tensor=True, normalize_embeddings=True)
    
    # If using MRL mode, slice teacher embeddings for fair comparison
    if not USE_PROJECTION:
        target_dim = student_model.get_output_dim()
        teacher_query_emb = teacher_query_emb[:, :target_dim]
        teacher_doc_emb = teacher_doc_emb[:, :target_dim]
        teacher_query_emb = torch.nn.functional.normalize(teacher_query_emb, p=2, dim=-1)
        teacher_doc_emb = torch.nn.functional.normalize(teacher_doc_emb, p=2, dim=-1)
        print(f"\nSimilarity scores (teacher model, first {target_dim}d, MRL slice):")
    else:
        print(f"\nSimilarity scores (teacher model, full {TEACHER_DIM}d):")
    
    teacher_similarities = torch.matmul(teacher_query_emb, teacher_doc_emb.T)
    print(teacher_similarities)
    
    print(f"\nMean absolute difference: {torch.abs(similarities - teacher_similarities).mean().item():.6f}")