# Performance Tuning for Production

Optimize HybridRAG for production deployment.

**What you'll learn:**
- Search parameter optimization
- Caching strategies
- Batch processing
- Performance benchmarking
- Cost optimization

**Prerequisites:** Completed notebooks 01-02

In [None]:
# Setup
import asyncio
import time
from typing import List
from hybridrag import create_hybridrag
from hybridrag.config import get_settings
from dotenv import load_dotenv

load_dotenv()
rag = await create_hybridrag()
print("âœ“ HybridRAG initialized")

## 1. Baseline Performance Measurement

In [None]:
async def measure_query_latency(query: str, mode: str, top_k: int) -> float:
    """Measure the elapsed time of a single ``rag.query`` call.

    Args:
        query: Search text passed to ``rag.query``.
        mode: Search mode passed to ``rag.query`` (this notebook uses
            "vector", "keyword" and "hybrid").
        top_k: Number of results requested.

    Returns:
        Elapsed time in seconds. The query results themselves are discarded.
    """
    # perf_counter is monotonic and high-resolution; time.time() can jump with
    # system clock adjustments and is coarse on some platforms.
    start = time.perf_counter()
    await rag.query(query=query, mode=mode, top_k=top_k)
    return time.perf_counter() - start

# Test queries
test_queries = [
    "What is MongoDB Atlas?",
    "How does vector search work?",
    "mongodb hybrid search configuration",
]

print("Baseline Latency Measurements:\n")

for mode in ["vector", "keyword", "hybrid"]:
    latencies = []
    for query in test_queries:
        latency = await measure_query_latency(query, mode, top_k=10)
        latencies.append(latency)
    
    avg_latency = sum(latencies) / len(latencies)
    print(f"{mode.upper()} Mode:")
    print(f"  Average: {avg_latency*1000:.2f}ms")
    print(f"  Min: {min(latencies)*1000:.2f}ms")
    print(f"  Max: {max(latencies)*1000:.2f}ms\n")

## 2. Parameter Optimization

In [None]:
# Test different top_k values
query = "mongodb atlas vector search"
top_k_values = [5, 10, 20, 50]

print("Impact of top_k on latency:\n")

for top_k in top_k_values:
    latency = await measure_query_latency(query, "hybrid", top_k)
    print(f"top_k={top_k:3d}: {latency*1000:6.2f}ms (numCandidates={top_k*20})")

print("\nRecommendation:")
print("- top_k=10: Good balance for most use cases")
print("- top_k=5: Faster, acceptable for quick searches")
print("- top_k=20+: Slower, use only when higher recall needed")

## 3. Batch Processing

In [None]:
# Compare sequential vs batch processing
queries = [
    "mongodb vector search",
    "atlas search features",
    "hybrid search setup",
    "knowledge graph mongodb",
    "rag system architecture",
]

# Sequential processing
start = time.time()
sequential_results = []
for query in queries:
    results = await rag.query(query=query, mode="hybrid", top_k=5)
    sequential_results.append(results)
sequential_time = time.time() - start

# Batch processing (parallel)
start = time.time()
batch_tasks = [rag.query(query=q, mode="hybrid", top_k=5) for q in queries]
batch_results = await asyncio.gather(*batch_tasks)
batch_time = time.time() - start

print(f"Sequential: {sequential_time*1000:.2f}ms ({sequential_time*1000/len(queries):.2f}ms per query)")
print(f"Batch: {batch_time*1000:.2f}ms ({batch_time*1000/len(queries):.2f}ms per query)")
print(f"Speedup: {sequential_time/batch_time:.2f}x\n")

print("Recommendation:")
print("- Use batch processing for multiple queries")
print("- Significant speedup for I/O-bound operations")
print("- Monitor MongoDB connection limits")

## 4. Caching Strategy

In [None]:
# Simple cache implementation
from functools import lru_cache
import hashlib

class QueryCache:
    """Bounded in-memory cache of query results with FIFO eviction.

    Entries are keyed by an MD5 digest of ``query``, ``mode`` and ``top_k``.
    When the cache is full, the oldest *inserted* entry is evicted; reads do
    not refresh an entry's position (FIFO, not LRU). There is no TTL.
    """

    def __init__(self, max_size: int = 100):
        self.cache = {}  # digest -> results; dicts preserve insertion order
        self.max_size = max_size  # values <= 0 disable caching

    def get_key(self, query: str, mode: str, top_k: int) -> str:
        """Return the MD5 hex digest identifying (query, mode, top_k)."""
        content = f"{query}:{mode}:{top_k}"
        # MD5 is used only for a compact fixed-length key, not for security.
        return hashlib.md5(content.encode()).hexdigest()

    def get(self, query: str, mode: str, top_k: int):
        """Return the cached results, or None on a miss."""
        return self.cache.get(self.get_key(query, mode, top_k))

    def set(self, query: str, mode: str, top_k: int, results):
        """Store ``results``, evicting the oldest entries if the cache is full."""
        key = self.get_key(query, mode, top_k)
        if key in self.cache:
            # Overwriting does not grow the cache, so no other entry should be
            # evicted; the entry keeps its original FIFO position.
            self.cache[key] = results
            return
        if self.max_size <= 0:
            # Caching disabled; also avoids next(iter({})) raising StopIteration.
            return
        # `while` (not `if`) so a max_size lowered after construction is enforced.
        while len(self.cache) >= self.max_size:
            self.cache.pop(next(iter(self.cache)))
        self.cache[key] = results

cache = QueryCache(max_size=100)

async def cached_query(query: str, mode: str = "hybrid", top_k: int = 10):
    """Run ``rag.query`` through the module-level ``cache``.

    Args:
        query: Search text.
        mode: Search mode passed to ``rag.query``.
        top_k: Number of results requested.

    Returns:
        ``(results, hit)`` where ``hit`` is True when the results came from
        the cache and False when ``rag.query`` was called.
    """
    cached = cache.get(query, mode, top_k)
    # The cache signals a miss with None; compare against it explicitly so an
    # empty (falsy) result set is still served from the cache.
    if cached is not None:
        return cached, True  # Cache hit

    results = await rag.query(query=query, mode=mode, top_k=top_k)
    cache.set(query, mode, top_k, results)
    return results, False  # Cache miss

# Test caching
query = "mongodb vector search atlas"

# First query (cache miss)
start = time.time()
results1, hit1 = await cached_query(query)
time1 = time.time() - start

# Second query (cache hit)
start = time.time()
results2, hit2 = await cached_query(query)
time2 = time.time() - start

print(f"First query (miss): {time1*1000:.2f}ms")
print(f"Second query (hit): {time2*1000:.2f}ms")
print(f"Speedup: {time1/time2:.1f}x\n")

print("Caching Best Practices:")
print("- Cache frequently asked queries")
print("- Set appropriate TTL (e.g., 5-15 minutes)")
print("- Use Redis/Memcached for distributed caching")
print("- Monitor cache hit rate (target: >70%)")

## 5. Embedding Cost Optimization

In [None]:
# Calculate embedding costs
# NOTE(review): `settings` is not referenced anywhere in this cell; the MongoDB
# index cell calls get_settings() again itself.
settings = get_settings()

# Voyage AI pricing (example figures; verify against Voyage AI's current price
# list before relying on them). Values are USD per token.
VOYAGE_PRICING = {
    "voyage-3": 0.06 / 1_000_000,  # $0.06 per 1M tokens
    "voyage-3-large": 0.13 / 1_000_000,  # $0.13 per 1M tokens
}

def estimate_embedding_cost(texts: List[str], model: str = "voyage-3") -> dict:
    """Estimate the token count and USD cost of embedding ``texts``.

    Tokens are approximated as 1.3 per whitespace-separated word. A model not
    present in ``VOYAGE_PRICING`` is priced at the voyage-3 rate.

    Args:
        texts: Documents to embed.
        model: Key into ``VOYAGE_PRICING``.

    Returns:
        Dict with ``total_texts``, ``total_tokens`` (truncated to int),
        ``cost`` (USD) and ``cost_per_text`` (USD; 0.0 when ``texts`` is empty).
    """
    total_tokens = sum(len(text.split()) * 1.3 for text in texts)  # ~1.3 tokens per word
    cost_per_token = VOYAGE_PRICING.get(model, 0.06 / 1_000_000)
    total_cost = total_tokens * cost_per_token

    return {
        "total_texts": len(texts),
        "total_tokens": int(total_tokens),
        "cost": total_cost,
        # Guard the division so an empty input yields zeros instead of raising.
        "cost_per_text": total_cost / len(texts) if texts else 0.0,
    }

# Example: 1000 documents
# The trailing space keeps repeated sentences as separate words; without it
# "Atlas.This" is counted as one word and the token estimate is understated.
sample_docs = ["This is a sample document about MongoDB Atlas. " * 10] * 1000

for model in ["voyage-3", "voyage-3-large"]:
    cost_estimate = estimate_embedding_cost(sample_docs, model)
    print(f"{model}:")
    print(f"  Documents: {cost_estimate['total_texts']:,}")
    print(f"  Tokens: {cost_estimate['total_tokens']:,}")
    print(f"  Total cost: ${cost_estimate['cost']:.4f}")
    print(f"  Cost per doc: ${cost_estimate['cost_per_text']:.6f}\n")

print("Cost Optimization Tips:")
print("- Use voyage-3 (smaller model) when possible")
print("- Batch embeddings (up to 128 texts per API call)")
print("- Cache embeddings - don't re-embed unchanged content")
print("- Chunk documents efficiently (512-1024 tokens per chunk)")
## 6. MongoDB Index Optimization

In [None]:
# Check MongoDB indexes
from pymongo import MongoClient

settings = get_settings()

# Using MongoClient as a context manager closes it even if listing indexes raises.
with MongoClient(settings.MONGODB_URI) as client:
    db = client[settings.MONGODB_DATABASE]
    collection = db[settings.COLLECTION_NAME]

    print("Current Indexes:\n")
    # NOTE(review): list_indexes() returns standard MongoDB indexes only; Atlas
    # Search / Vector Search indexes are listed via
    # collection.list_search_indexes() (PyMongo 4.5+) — confirm driver version.
    for idx in collection.list_indexes():
        print(f"Name: {idx['name']}")
        print(f"Keys: {idx['key']}")
        if 'type' in idx:
            print(f"Type: {idx['type']}")
        print()

print("Index Optimization Tips:")
print("\nVector Search Index:")
print("- Use HNSW for better performance")
print("- Tune numDimensions to match embeddings (1024 for voyage-3)")
print("- Set similarity to 'cosine' for normalized vectors")

print("\nAtlas Search Index:")
print("- Index only fields you actually search")
print("- Use analyzers appropriate for your content")
print("- Enable autocomplete only if needed")

print("\nCompound Indexes:")
print("- Create indexes for common filter fields")
print("- Order matters: most selective field first")
print("- Monitor index usage with explain()")

## 7. Performance Monitoring

In [None]:
import statistics

class PerformanceMonitor:
    """Record per-query latencies and summarize them.

    Latencies are recorded in seconds and reported in milliseconds.
    """

    def __init__(self):
        self.query_times = []  # every recorded latency, in seconds
        # Queries per search mode; modes outside this initial set are added on first use.
        self.query_counts = {"vector": 0, "keyword": 0, "hybrid": 0}

    def record_query(self, mode: str, latency: float):
        """Record one query's latency (in seconds) under its search mode."""
        self.query_times.append(latency)
        # .get() so an unexpected mode is counted rather than raising KeyError.
        self.query_counts[mode] = self.query_counts.get(mode, 0) + 1

    def _percentile_ms(self, buckets: int) -> float:
        """Return the highest of ``buckets`` quantile cut points, in milliseconds.

        ``buckets=20`` gives P95 and ``buckets=100`` gives P99. With fewer
        samples than ``buckets`` the maximum latency is returned instead, as a
        conservative stand-in for a poorly-supported estimate.
        """
        if len(self.query_times) >= buckets:
            value = statistics.quantiles(self.query_times, n=buckets)[-1]
        else:
            value = max(self.query_times)
        return value * 1000

    def get_stats(self) -> dict:
        """Summarize the recorded latencies.

        Returns:
            ``{}`` when nothing has been recorded. Otherwise a dict with
            ``total_queries``, ``avg_latency_ms``, ``median_latency_ms``,
            ``p95_latency_ms``, ``p99_latency_ms``, ``max_latency_ms`` and
            ``mode_distribution`` (a copy of the per-mode counts).
        """
        if not self.query_times:
            return {}

        return {
            "total_queries": len(self.query_times),
            "avg_latency_ms": statistics.mean(self.query_times) * 1000,
            "median_latency_ms": statistics.median(self.query_times) * 1000,
            "p95_latency_ms": self._percentile_ms(20),
            "p99_latency_ms": self._percentile_ms(100),
            "max_latency_ms": max(self.query_times) * 1000,
            # Copy so callers cannot mutate the monitor's internal counters.
            "mode_distribution": dict(self.query_counts),
        }

# Simulate production load
monitor = PerformanceMonitor()

test_queries = [
    ("mongodb atlas", "hybrid"),
    ("vector search", "vector"),
    ("full text search", "keyword"),
] * 10

print("Simulating production load...\n")

for query, mode in test_queries:
    latency = await measure_query_latency(query, mode, top_k=10)
    monitor.record_query(mode, latency)

stats = monitor.get_stats()

print("Performance Statistics:")
print(f"\nTotal queries: {stats['total_queries']}")
print(f"Average latency: {stats['avg_latency_ms']:.2f}ms")
print(f"Median latency: {stats['median_latency_ms']:.2f}ms")
print(f"P95 latency: {stats['p95_latency_ms']:.2f}ms")
print(f"Max latency: {stats['max_latency_ms']:.2f}ms")

print(f"\nMode distribution:")
for mode, count in stats['mode_distribution'].items():
    print(f"  {mode}: {count} ({count/stats['total_queries']*100:.1f}%)")

print("\nTarget SLAs:")
print("  - P50 (median): <200ms")
  print("  - P95: <500ms")
print("  - P99: <1000ms")

## 8. Production Checklist

### Search Optimization:
- [ ] Set appropriate `top_k` (default: 10)
- [ ] Configure `numCandidates` (automatic: top_k * 20)
- [ ] Enable `scoreDetails` for debugging
- [ ] Tune vector/text weights based on query type

### Caching:
- [ ] Implement query result caching
- [ ] Cache embeddings for unchanged content
- [ ] Set appropriate TTL (5-15 minutes)
- [ ] Monitor cache hit rate (target: >70%)

### MongoDB:
- [ ] Create appropriate indexes
- [ ] Monitor index usage with explain()
- [ ] Use connection pooling
- [ ] Set a read preference (e.g., `secondaryPreferred`) to route reads to secondaries where slightly stale data is acceptable

### Cost Optimization:
- [ ] Use smaller embedding model when possible
- [ ] Batch API calls
- [ ] Implement request deduplication
- [ ] Monitor API usage

### Monitoring:
- [ ] Track query latency (P50, P95, P99)
- [ ] Monitor error rates
- [ ] Set up alerts for SLA violations
- [ ] Log slow queries (>1s)

## Conclusion

You've learned:
- Performance measurement and benchmarking
- Parameter optimization strategies
- Caching implementation
- Cost optimization techniques
- Production monitoring

**Next steps:**
- Deploy to production
- Monitor real-world performance
- Iterate based on metrics
- Scale as needed