# Hybrid Search Demo - Phase 3 Week 2

This notebook demonstrates the full Hybrid Search Pipeline, combining:
- **Stage 1**: Fast CLIP bi-encoder retrieval (<100ms)
- **Stage 2**: Accurate BLIP-2 cross-encoder re-ranking

Features showcased:
- Single query text-to-image search
- Batch query processing (2-6x speedup)
- Image-to-image search
- Method comparisons (CLIP vs Hybrid)
- Performance benchmarking
- Accuracy evaluation (Recall@10 on a small query sample)

---

## 1. Setup and Imports

In [None]:
import sys
import time
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
from typing import List, Tuple

# Auto-detect the runtime (Kaggle vs. local), pick the data directory, and make the
# project root importable so the `src.*` imports below resolve.
if Path('/kaggle/input').exists():
    # Kaggle environment
    project_root = Path('/kaggle/working/hybrid_multimodal_retrieval')
    DATA_DIR = Path('/kaggle/input/flickr30k')
    print("Running on Kaggle")
else:
    # Local environment: assumes the notebook runs from a directory one level below
    # the project root (e.g. notebooks/) — adjust if that is not the case.
    project_root = Path.cwd().parent
    DATA_DIR = project_root / 'data'
    print("Running locally")

# Put the project root FIRST on sys.path in both environments so its `src` package
# takes precedence over any other importable `src`. Remove an earlier copy first so
# re-running this cell does not keep growing sys.path.
project_root_str = str(project_root)
if project_root_str in sys.path:
    sys.path.remove(project_root_str)
sys.path.insert(0, project_root_str)

from src.retrieval.bi_encoder import BiEncoder
from src.retrieval.cross_encoder import CrossEncoder
from src.retrieval.faiss_index import FAISSIndex
from src.retrieval.hybrid_search import HybridSearchEngine
from src.flickr30k.dataset import Flickr30KDataset

print("✓ Imports successful")
print(f"  Data directory: {DATA_DIR}")

## 2. Load Models and Data

In [None]:
print("Loading components...\n")

# Single device setting shared by every model/index loaded below.
DEVICE = 'cuda'

# Load Flickr30K dataset
print("[1/4] Loading Flickr30K dataset...")
dataset = Flickr30KDataset(
    images_dir=str(DATA_DIR / 'images'),
    annotations_file=str(DATA_DIR / 'results.csv')
)
print(f"  ✓ Loaded {len(dataset)} images\n")

# Load CLIP bi-encoder
print("[2/4] Loading CLIP bi-encoder...")
bi_encoder = BiEncoder(model_name='ViT-B/32', device=DEVICE)
print(f"  ✓ Model: {bi_encoder.model_name}\n")

# Load FAISS index
print("[3/4] Loading FAISS index...")
index_path = DATA_DIR / 'indices' / 'image_index.faiss'
# Fail early with an actionable message rather than whatever low-level error the
# index loader would raise for a missing file.
if not index_path.exists():
    raise FileNotFoundError(
        f"FAISS image index not found at {index_path}. "
        "Build the image index before running this notebook."
    )
image_index = FAISSIndex(device=DEVICE)
image_index.load(str(index_path))
print(f"  ✓ Loaded {image_index.index.ntotal:,} image vectors\n")

# Load BLIP-2 cross-encoder
print("[4/4] Loading BLIP-2 cross-encoder...")
cross_encoder = CrossEncoder(
    model_name='Salesforce/blip2-flan-t5-xl',
    device=DEVICE
)
print(f"  ✓ Model: {cross_encoder.model_name}\n")

print("="*60)
print("All components loaded successfully! ✓")
print("="*60)

## 3. Initialize Hybrid Search Engine

In [None]:
# Hybrid engine settings: Stage 1 fetches k1 CLIP candidates, Stage 2 re-ranks
# them with BLIP-2 and keeps the top k2.
hybrid_config = {
    'k1': 100,            # Stage 1 candidates
    'k2': 10,             # Final results
    'batch_size': 4,      # Cross-encoder batch size
    'use_cache': True,    # Enable caching
    'show_progress': True,
}

engine = HybridSearchEngine(
    bi_encoder=bi_encoder,
    cross_encoder=cross_encoder,
    image_index=image_index,
    dataset=dataset,
    config=hybrid_config,
)

print(engine)
print("\n✓ HybridSearchEngine ready!")

## 4. Single Query Example

In [None]:
# Run a single text-to-image hybrid search and print the ranked results.
query = "a dog playing in the park"
top_k = 5  # number of final (re-ranked) results requested

print(f"Query: '{query}'\n")

# perf_counter() is monotonic and high-resolution, unlike time.time(), so it is
# the appropriate clock for measuring elapsed durations.
start = time.perf_counter()
results = engine.text_to_image_hybrid_search(
    query=query,
    k1=100,
    k2=top_k,
    show_progress=True
)
latency = (time.perf_counter() - start) * 1000

print(f"\n✓ Search completed in {latency:.2f}ms")
# Report how many results actually came back instead of hard-coding the count.
print(f"\nTop {len(results)} Results:")
for i, (image_id, score) in enumerate(results, 1):
    print(f"  {i}. {image_id} - Score: {score:.4f}")

### Visualize Results

In [None]:
def visualize_results(query: str, results: List[Tuple[str, float]], dataset, n: int = 5):
    """Visualize top-n search results as a single row of image panels.

    Args:
        query: Query text, shown as the figure title.
        results: Ranked ``(image_id, score)`` pairs; only the first ``n`` are drawn.
        dataset: Object whose ``get_by_image_id(image_id)`` returns a mapping with
            an ``'image_path'`` entry.
        n: Number of panels. If ``results`` has fewer than ``n`` entries, the
            remaining panels are left blank.
    """
    # squeeze=False always yields a 2-D array of Axes, so n == 1 needs no special case.
    fig, axes = plt.subplots(1, n, figsize=(4*n, 4), squeeze=False)
    row = axes[0]

    fig.suptitle(f"Query: '{query}'", fontsize=14, fontweight='bold')

    # Hide every panel up front so slots without a result render as blank space
    # instead of empty framed plots.
    for ax in row:
        ax.axis('off')

    for idx, (image_id, score) in enumerate(results[:n]):
        item = dataset.get_by_image_id(image_id)
        # The with-block closes the file handle once the RGB copy has been made.
        with Image.open(item['image_path']) as raw_img:
            img = raw_img.convert('RGB')

        row[idx].imshow(img)
        row[idx].set_title(f"Rank {idx+1}\nScore: {score:.4f}", fontsize=10)

    plt.tight_layout()
    plt.show()

# Visualize the results
visualize_results(query, results, dataset, n=5)

## 5. Side-by-Side Comparison: CLIP vs Hybrid

In [None]:
# Compare CLIP-only (Stage 1) retrieval with the full hybrid pipeline on one query.
# The timings are not like-for-like: CLIP-only retrieves 5 results, while the hybrid
# run retrieves k1=100 candidates and re-ranks them with BLIP-2. With use_cache=True,
# re-running this cell may serve the hybrid query from the cache (TODO confirm),
# which would understate the hybrid latency.
query = "a cat sitting on a couch"

print(f"Query: '{query}'\n")
print("="*60)

# CLIP-only (Stage 1). _stage1_retrieve is a private engine method, used here only
# for demonstration.
print("\n[1] CLIP-only Search:")
start = time.perf_counter()
clip_results = engine._stage1_retrieve(query, k=5)
clip_latency = (time.perf_counter() - start) * 1000
print(f"  Latency: {clip_latency:.2f}ms")

# Hybrid (Stage 1 + Stage 2)
print("\n[2] Hybrid Search (CLIP + BLIP-2):")
start = time.perf_counter()
hybrid_results = engine.text_to_image_hybrid_search(
    query=query,
    k1=100,
    k2=5,
    show_progress=True
)
hybrid_latency = (time.perf_counter() - start) * 1000
print(f"  Latency: {hybrid_latency:.2f}ms")

print("\n" + "="*60)
print("COMPARISON:")
print(f"  CLIP:   {clip_latency:.2f}ms")
print(f"  Hybrid: {hybrid_latency:.2f}ms")
# '+' format flag prints the sign explicitly, so a negative difference is shown correctly.
print(f"  Overhead: {hybrid_latency - clip_latency:+.2f}ms for better accuracy")
print("="*60)

### Visualize Comparison

In [None]:
def visualize_comparison(query: str, clip_results, hybrid_results, dataset, n: int = 4):
    """Visualize CLIP-only vs Hybrid results side-by-side.

    The top row shows CLIP-only results and the bottom row shows hybrid results.
    The first column of each row holds a text label, followed by up to ``n`` images.

    Args:
        query: Query text, shown as the figure title.
        clip_results: Ranked ``(image_id, score)`` pairs from CLIP-only retrieval.
        hybrid_results: Ranked ``(image_id, score)`` pairs from hybrid search.
        dataset: Object whose ``get_by_image_id(image_id)`` returns a mapping with
            an ``'image_path'`` entry.
        n: Number of images per row (default 4, matching the original layout).
            Slots without a result are left blank.
    """
    rows = [
        ('CLIP-only\n(Fast)', clip_results),
        ('Hybrid\n(CLIP+BLIP-2)\n(Accurate)', hybrid_results),
    ]
    # One label column plus n image columns; squeeze=False keeps indexing 2-D for any n.
    fig, axes = plt.subplots(len(rows), n + 1, figsize=(4 * (n + 1), 8), squeeze=False)

    fig.suptitle(f"Query: '{query}'", fontsize=16, fontweight='bold')

    for row_axes, (label, row_results) in zip(axes, rows):
        # Hide every panel first so unused slots render as blank space.
        for ax in row_axes:
            ax.axis('off')

        row_axes[0].text(0.5, 0.5, label,
                         ha='center', va='center', fontsize=12, fontweight='bold')

        for idx, (image_id, score) in enumerate(row_results[:n]):
            item = dataset.get_by_image_id(image_id)
            # The with-block closes the file handle once the RGB copy has been made.
            with Image.open(item['image_path']) as raw_img:
                img = raw_img.convert('RGB')
            ax = row_axes[idx + 1]
            ax.imshow(img)
            ax.set_title(f"#{idx+1} | {score:.3f}", fontsize=10)

    plt.tight_layout()
    plt.show()

visualize_comparison(query, clip_results, hybrid_results, dataset)

## 6. Batch Search Demo

In [None]:
# Run several queries through the batched hybrid search and show per-query results.
queries = [
    "a dog running on the beach",
    "a person riding a bicycle",
    "a group of people at a party",
    "sunset over the ocean",
    "a child playing with toys"
]

print(f"Processing {len(queries)} queries in batch...\n")

# perf_counter() is the monotonic, high-resolution clock intended for durations.
start = time.perf_counter()
batch_results = engine.batch_text_to_image_search(
    queries=queries,
    k1=100,
    k2=5,
    show_progress=True
)
batch_latency = time.perf_counter() - start  # seconds

print(f"\n✓ Batch search completed in {batch_latency:.2f}s")
print(f"  Average per query: {(batch_latency/len(queries))*1000:.2f}ms")

# Show the top 3 results for each query. The loop variables use distinct names so
# this cell does not overwrite the `query` / `results` globals that the earlier
# visualization cells read.
print("\nResults per query:")
for i, batch_query in enumerate(queries):
    query_results = batch_results[batch_query]  # batch_results is indexed by query text
    print(f"\n{i+1}. '{batch_query}'")
    for j, (image_id, score) in enumerate(query_results[:3], 1):
        print(f"     {j}. {image_id} - {score:.4f}")

### Compare Batch vs Sequential

In [None]:
# Compare batched vs. sequential hybrid search on the same queries.
#
# Both modes are timed here, back to back, with the result cache disabled. With
# use_cache=True, queries already answered by the batch run in the previous cell
# (or by an earlier run of this cell) could be served from the cache — presumably
# the cache is shared by single and batch searches; TODO confirm — which would make
# whichever mode runs second look artificially fast.
print("Running sequential search for comparison...\n")

# NOTE(review): assumes get_config() reports the 'use_cache' key passed to the
# constructor and update_config() accepts it as a keyword argument — confirm.
cache_was_enabled = engine.get_config().get('use_cache', True)
engine.update_config(use_cache=False)
try:
    start = time.perf_counter()
    sequential_results = {}
    for seq_query in queries:
        sequential_results[seq_query] = engine.text_to_image_hybrid_search(
            query=seq_query,
            k1=100,
            k2=5,
            show_progress=False
        )
    sequential_latency = time.perf_counter() - start

    # Re-time the batch path under the same uncached conditions.
    start = time.perf_counter()
    engine.batch_text_to_image_search(
        queries=queries,
        k1=100,
        k2=5,
        show_progress=False
    )
    uncached_batch_latency = time.perf_counter() - start
finally:
    # Restore the caching behaviour the rest of the notebook uses.
    engine.update_config(use_cache=cache_was_enabled)

# Comparison
n_queries = len(queries)
print("="*60)
print("BATCH vs SEQUENTIAL COMPARISON")
print("="*60)
print(f"Sequential: {sequential_latency:.2f}s ({(sequential_latency/n_queries)*1000:.2f}ms/query)")
print(f"Batch:      {uncached_batch_latency:.2f}s ({(uncached_batch_latency/n_queries)*1000:.2f}ms/query)")
print(f"Speedup:    {sequential_latency/uncached_batch_latency:.2f}x")
print("="*60)

## 7. Image-to-Image Search

In [None]:
# Image-to-image search: use the first dataset image as the query.
query_item = dataset[0]
query_image_id = query_item['image_id']
query_image_path = query_item['image_path']

print(f"Query Image: {query_image_id}\n")

# Display query image. The with-block closes the file handle once the RGB copy exists.
with Image.open(query_image_path) as raw_query_img:
    query_img = raw_query_img.convert('RGB')
plt.figure(figsize=(5, 5))
plt.imshow(query_img)
plt.title(f"Query Image: {query_image_id}", fontweight='bold')
plt.axis('off')
plt.show()

# Run image-to-image search
print("\nSearching for similar images...\n")

start = time.perf_counter()
results = engine.image_to_image_search(
    image_id=query_image_id,
    k=6  # one extra result, since the query image itself is expected among them
)
latency = (time.perf_counter() - start) * 1000

print(f"✓ Search completed in {latency:.2f}ms\n")

# Drop the query image by id rather than assuming it is ranked first. With tied
# scores it may not be, and slicing results[1:] would then discard a genuine
# neighbour while keeping the query.
similar_results = [
    (image_id, score) for image_id, score in results if image_id != query_image_id
][:5]

print(f"Top {len(similar_results)} Similar Images:")
for i, (image_id, score) in enumerate(similar_results, 1):
    print(f"  {i}. {image_id} - Similarity: {score:.4f}")

### Visualize Similar Images

In [None]:
# Plot the images most similar to the query image, excluding the query itself by id
# (it is not assumed to be ranked first).
similar_to_plot = [
    (image_id, score) for image_id, score in results if image_id != query_image_id
][:5]

fig, axes = plt.subplots(1, 5, figsize=(20, 4))
fig.suptitle(f"Similar to: {query_image_id}", fontsize=14, fontweight='bold')

# Hide every panel up front so any slot without a result stays blank.
for ax in axes:
    ax.axis('off')

for idx, (image_id, score) in enumerate(similar_to_plot):
    item = dataset.get_by_image_id(image_id)
    # The with-block closes the file handle once the RGB copy has been made.
    with Image.open(item['image_path']) as raw_img:
        img = raw_img.convert('RGB')

    axes[idx].imshow(img)
    axes[idx].set_title(f"#{idx+1}\nSimilarity: {score:.4f}", fontsize=10)

plt.tight_layout()
plt.show()

## 8. Performance Benchmarks

In [None]:
# Benchmark hybrid-search latency across several (k1, k2, batch_size) settings.
#
# The result cache is disabled while timing. With use_cache=True, every repeat of
# the same query after the first could be answered from the cache, so the average
# would mostly measure cache lookups. The engine's previous settings are restored
# afterwards so later cells (accuracy, statistics) do not silently run with the
# last benchmark configuration.
print("Running performance benchmarks...\n")
print("="*70)

test_query = "a dog running in a field"
n_timed_runs = 5

configs = [
    {'k1': 50, 'k2': 5, 'batch_size': 2},
    {'k1': 100, 'k2': 10, 'batch_size': 4},
    {'k1': 200, 'k2': 20, 'batch_size': 8},
]

# NOTE(review): assumes get_config() returns the keys given to the constructor and
# update_config() accepts them back as keyword arguments — confirm in HybridSearchEngine.
saved_config = engine.get_config()
restore_config = {
    key: saved_config[key]
    for key in ('k1', 'k2', 'batch_size', 'use_cache')
    if key in saved_config
}

results_list = []

try:
    engine.update_config(use_cache=False)

    for config in configs:
        print(f"\nConfig: k1={config['k1']}, k2={config['k2']}, batch_size={config['batch_size']}")

        # Apply the config; batch_size is only set here, since the search call
        # below passes k1/k2 explicitly but not batch_size.
        engine.update_config(**config)

        # Run multiple times for average
        latencies = []
        for _ in range(n_timed_runs):
            start = time.perf_counter()
            _ = engine.text_to_image_hybrid_search(
                query=test_query,
                k1=config['k1'],
                k2=config['k2'],
                show_progress=False
            )
            latencies.append((time.perf_counter() - start) * 1000)

        avg_latency = np.mean(latencies)
        print(f"  Average latency: {avg_latency:.2f}ms")

        results_list.append({
            'config': f"k1={config['k1']}, k2={config['k2']}",
            'latency': avg_latency
        })
finally:
    engine.update_config(**restore_config)

print("\n" + "="*70)

### Visualize Performance

In [None]:
# Bar chart of average latency per benchmark configuration, with the 2000 ms target line.
configs_labels = []
latencies = []
for entry in results_list:
    configs_labels.append(entry['config'])
    latencies.append(entry['latency'])

fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(configs_labels, latencies, color=['skyblue', 'lightcoral', 'lightgreen'])
ax.axhline(y=2000, color='r', linestyle='--', label='Target: <2000ms')
ax.set_xlabel('Configuration', fontsize=12)
ax.set_ylabel('Latency (ms)', fontsize=12)
ax.set_title('Hybrid Search Latency vs Configuration', fontsize=14, fontweight='bold')
ax.legend()
plt.setp(ax.get_xticklabels(), rotation=15)
fig.tight_layout()
plt.show()

## 9. Accuracy Evaluation

In [None]:
# Quick Recall@10 comparison of CLIP-only vs. hybrid search.
#
# Each test query is the first caption of a sampled image. A hit is counted when
# that image appears among the top-10 results.
print("Running accuracy evaluation...\n")
print("This tests CLIP vs Hybrid on a sample of queries")
print("="*70)

num_eval_samples = 10
eval_k = 10  # Recall@K cutoff

# Sample images evenly across the dataset. step is clamped to >= 1 and the sample
# count to len(dataset), so a dataset smaller than num_eval_samples yields distinct
# images instead of repeating index 0.
step = max(1, len(dataset) // num_eval_samples)
test_queries = []
for i in range(min(num_eval_samples, len(dataset))):
    item = dataset[i * step]
    if item['captions']:
        test_queries.append({
            'query': item['captions'][0],
            'ground_truth': item['image_id']
        })

# Evaluate CLIP-only. Distinct variable names avoid overwriting the `results`
# global that the image-to-image cells read.
clip_correct = 0
for test in test_queries:
    clip_hits = engine._stage1_retrieve(test['query'], k=eval_k)
    if test['ground_truth'] in {img_id for img_id, _ in clip_hits}:
        clip_correct += 1

# Evaluate Hybrid
hybrid_correct = 0
for test in test_queries:
    hybrid_hits = engine.text_to_image_hybrid_search(
        query=test['query'],
        k1=100,
        k2=eval_k,
        show_progress=False
    )
    if test['ground_truth'] in {img_id for img_id, _ in hybrid_hits}:
        hybrid_correct += 1

# Results
print("\n" + "="*70)
print("ACCURACY RESULTS (Recall@10)")
print("="*70)
n_queries = len(test_queries)
if n_queries == 0:
    # Avoid ZeroDivisionError when no sampled image had a caption.
    print("No test queries with captions were sampled; nothing to report.")
else:
    clip_recall = clip_correct / n_queries
    hybrid_recall = hybrid_correct / n_queries
    print(f"CLIP-only:  {clip_correct}/{n_queries} = {clip_recall:.2%}")
    print(f"Hybrid:     {hybrid_correct}/{n_queries} = {hybrid_recall:.2%}")
    # A difference of two rates is in percentage points. The '+' flag prints the
    # sign explicitly, so a regression shows as negative.
    print(f"Improvement: {(hybrid_recall - clip_recall) * 100:+.1f} percentage points")
print("="*70)

### Visualize Accuracy Comparison

In [None]:
# Bar chart of Recall@10 for CLIP-only vs. hybrid search, from the evaluation above.
n_queries = len(test_queries)
if n_queries == 0:
    # Avoid a ZeroDivisionError when the evaluation produced no test queries.
    print("No test queries available; skipping accuracy chart.")
else:
    methods = ['CLIP-only', 'Hybrid']
    accuracies = [clip_correct / n_queries, hybrid_correct / n_queries]

    plt.figure(figsize=(8, 6))
    bars = plt.bar(methods, accuracies, color=['skyblue', 'lightcoral'])
    plt.ylabel('Recall@10', fontsize=12)
    plt.title('Accuracy Comparison: CLIP vs Hybrid', fontsize=14, fontweight='bold')
    plt.ylim(0, 1.0)

    # Add value labels just above each bar
    for bar, acc in zip(bars, accuracies):
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height,
                 f'{acc:.2%}',
                 ha='center', va='bottom', fontsize=12, fontweight='bold')

    plt.tight_layout()
    plt.show()

## 10. Statistics and Cache Info

In [None]:
# Summarise the engine's search counters, average latencies, cache usage and
# active configuration.
stats = engine.get_statistics()
rule = "=" * 70

print(rule)
print("ENGINE STATISTICS")
print(rule)

print("\nSearch Counts:")
count_rows = (
    ("Text-to-Image searches", 'text_to_image_count'),
    ("Image-to-Image searches", 'image_to_image_count'),
    ("Batch searches", 'batch_search_count'),
    ("Total searches", 'total_searches'),
)
for label, stat_key in count_rows:
    print(f"  {label}: {stats[stat_key]}")

print("\nAverage Latencies:")
latency_rows = (
    ("Stage 1 (CLIP)", 'avg_stage1_latency'),
    ("Stage 2 (BLIP-2)", 'avg_stage2_latency'),
    ("Total", 'avg_total_latency'),
)
for label, stat_key in latency_rows:
    avg_ms = stats[stat_key]
    # Only non-zero averages are printed.
    if avg_ms > 0:
        print(f"  {label}: {avg_ms:.2f}ms")

print("\nCache Statistics:")
print(f"  Cache hits: {stats['cache_hits']}")
total_searches = stats['total_searches']
if total_searches > 0:
    # NOTE(review): hits / total searches — if a batch search counts once but makes
    # several cache lookups this ratio could exceed 100%; confirm engine counting.
    hit_rate = stats['cache_hits'] / total_searches
    print(f"  Hit rate: {hit_rate:.2%}")
print(f"  Cache size: {engine.get_cache_size()} entries")

print("\n" + rule)
print("\nCurrent Configuration:")
config = engine.get_config()
for key, value in config.items():
    print(f"  {key}: {value}")
print(rule)

## Summary

This notebook demonstrated:

✅ **Core Features**:
- Single query text-to-image hybrid search
- Batch processing for multiple queries (2-6x speedup)
- Image-to-image similarity search

✅ **Performance** (design targets; actual numbers depend on hardware — see the benchmark cells above):
- Stage 1 (CLIP): <100ms
- End-to-end: <2000ms (with Stage 2)
- Batch mode: 2-6x faster than sequential

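✅ **Accuracy**:
- Hybrid search is expected to improve Recall@10 over CLIP-only; the 10-query sample in Section 9 is only a quick check, so use the full evaluation for a reliable estimate
- BLIP-2 re-ranking refines results for better relevance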

✅ **Configuration**:
- Runtime parameter tuning (k1, k2, batch_size)
- Caching for repeated queries
- Performance profiling and statistics

---

**Next Steps**:
- Run full accuracy evaluation with 100 queries (`scripts/evaluate_accuracy.py`)
- Profile different configurations (`scripts/test_configuration.py`)
- Test edge cases (`scripts/test_hybrid_search.py`)

For detailed API documentation, see `API_REFERENCE.md`.