In [None]:
# FAST Hybrid Search Evaluation Notebook
# This notebook replicates scripts/evaluate_accuracy.py in an interactive setting
# for quick smoke-test comparison between CLIP-only baseline and Hybrid Search (CLIP + BLIP-2)

import os
import sys
from pathlib import Path

# Make the repository root importable so the `src.*` imports below resolve.
# The root is taken to be the parent of the current working directory.
project_root = Path("..").resolve()
if sys.path.count(str(project_root)) == 0:
    # Prepend so this checkout takes precedence over other entries.
    sys.path.insert(0, str(project_root))

print("Project root:", project_root)
print("In sys.path:", str(project_root) in sys.path)

In [None]:
# Imports and constants
import time
import json
from typing import List, Dict, Any

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from src.retrieval.bi_encoder import BiEncoder
from src.retrieval.cross_encoder import CrossEncoder
from src.retrieval.faiss_index import FAISSIndex
from src.retrieval.hybrid_search import HybridSearchEngine
from src.flickr30k.dataset import Flickr30KDataset

# FAST MODE settings for the smoke test
FAST_SEED = 2025  # default seed for the query-sampling RNG
FAST_N = 25       # default number of test queries to draw
FAST_K1 = 30      # passed as k1 to the hybrid search engine
FAST_K2 = 10      # passed as k2; also the retrieval depth of the CLIP-only baseline

# Choose the data directory: the local checkout unless a Kaggle input mount exists
if not Path('/kaggle/input').exists():
    DATA_DIR = project_root / 'data'
    print("Running locally; DATA_DIR =", DATA_DIR)
else:
    DATA_DIR = Path('/kaggle/input/flickr30k/data')
    print("Running on Kaggle; DATA_DIR =", DATA_DIR)

In [None]:
# Utility and evaluation functions (adapted from scripts/evaluate_accuracy.py)

def load_components(device: str = 'cuda'):
    """Load dataset, encoders, and FAISS index (FAST MODE).

    All paths are resolved relative to the module-level ``DATA_DIR``:
    images from ``images/``, captions from ``results.csv`` and the image
    index from ``indices/image_index.faiss``.

    Args:
        device: Device string handed to the bi-encoder, the FAISS index and
            the cross-encoder. Defaults to ``'cuda'``, the value that used to
            be hard-coded. Other values are passed through unchanged; whether
            they are supported depends on the component classes.

    Returns:
        Tuple ``(bi_encoder, cross_encoder, image_index, dataset)``.
    """
    print("\n" + "=" * 70)
    print("LOADING COMPONENTS")
    print("=" * 70)

    # perf_counter is monotonic, so the reported duration cannot be skewed
    # by system clock adjustments while the models load.
    start_time = time.perf_counter()

    print("\n[1/4] Loading Flickr30K dataset...")
    dataset = Flickr30KDataset(
        images_dir=str(DATA_DIR / 'images'),
        captions_file=str(DATA_DIR / 'results.csv')
    )
    print(f"  ✓ Loaded {len(dataset)} images")

    print("\n[2/4] Loading CLIP bi-encoder...")
    bi_encoder = BiEncoder(model_name='ViT-B/32', device=device)
    print(f"  ✓ Model: {bi_encoder.model_name}")

    print("\n[3/4] Loading FAISS index...")
    image_index = FAISSIndex(device=device)
    index_path = DATA_DIR / 'indices' / 'image_index.faiss'
    image_index.load(str(index_path))
    print(f"  ✓ Loaded {image_index.index.ntotal:,} vectors")

    print("\n[4/4] Loading BLIP-2 cross-encoder...")
    cross_encoder = CrossEncoder(
        model_name='Salesforce/blip2-opt-2.7b',
        device=device,
        use_fp16=True
    )
    print(f"  ✓ Model: {cross_encoder.model_name}")

    load_time = time.perf_counter() - start_time
    print(f"\n{'=' * 70}")
    print(f"Components loaded in {load_time:.2f}s")
    print(f"{'=' * 70}\n")

    return bi_encoder, cross_encoder, image_index, dataset


def select_test_queries(dataset: Flickr30KDataset, n: int = FAST_N, seed: int = FAST_SEED) -> List[Dict[str, Any]]:
    """Draw a seeded, without-replacement sample of images and build queries.

    For every sampled image that has at least one caption, the first caption
    becomes the query text, the image id is the ground truth, and any
    remaining captions are kept as alternatives. Images without captions are
    skipped, so fewer than ``n`` queries may be returned.

    Args:
        dataset: Source of image ids (``get_unique_images``) and captions
            (``get_captions``).
        n: Requested number of images; capped at the number available.
        seed: Seed for ``np.random.default_rng``.

    Returns:
        List of dicts with keys ``'query'``, ``'ground_truth'`` and
        ``'alternatives'``.
    """
    print("\n" + "=" * 70)
    print(f"SELECTING {n} TEST QUERIES (FAST MODE)")
    print("=" * 70)

    image_pool = dataset.get_unique_images()
    sample_size = min(n, len(image_pool))
    sampled_ids = np.random.default_rng(seed).choice(
        image_pool, size=sample_size, replace=False
    )

    # Look up captions lazily, in sampling order, one image at a time.
    captions_by_image = (
        (image_id, dataset.get_captions(image_id)) for image_id in sampled_ids
    )
    test_queries = [
        {
            'query': captions[0],
            'ground_truth': image_id,
            'alternatives': captions[1:] if len(captions) > 1 else [],
        }
        for image_id, captions in captions_by_image
        if captions
    ]

    print(f"\n✓ Selected {len(test_queries)} test queries")
    return test_queries


def ndcg_at_k(rank: int, k: int = 10) -> float:
    """Calculate nDCG@k for a single relevant item at given rank.

    Args:
        rank: 1-based position of the relevant item in the result list.
            ``None``, ``0`` or a negative value means "not retrieved"; this
            matches ``calculate_metrics``, which treats a falsy rank as a miss.
        k: Cutoff; ranks beyond ``k`` contribute nothing.

    Returns:
        ``1 / log2(rank + 1)`` as a plain Python float when ``1 <= rank <= k``,
        otherwise ``0.0``.
    """
    # Rejecting rank < 1 avoids dividing by log2(1) == 0 (rank 0) and taking
    # the log of a non-positive number (negative ranks).
    if not rank or rank < 1 or rank > k:
        return 0.0
    # IDCG is 1.0 for a single relevant item, so DCG is already normalised.
    return float(1.0 / np.log2(rank + 1))


def calculate_metrics(results: List[Dict[str, Any]], k: int) -> Dict[str, Any]:
    """Compute Recall@k, MRR, nDCG@10, MAP and latency statistics.

    Args:
        results: One dict per query with at least ``'rank'`` (1-based rank of
            the ground truth, or ``None`` when it was not retrieved) and
            ``'latency'``.
        k: Accepted for interface compatibility. Currently unused: the recall
            cutoffs are fixed at 1, 5 and 10 and nDCG is computed at 10.

    Returns:
        Dict with ``'recall@1'``, ``'recall@5'``, ``'recall@10'``, ``'mrr'``,
        ``'ndcg@10'``, ``'map'``, ``'latencies'`` (``'mean'`` / ``'median'``)
        and ``'n_queries'``. For an empty ``results`` list every metric is
        ``0.0`` and ``'n_queries'`` is ``0``.
    """
    n_queries = len(results)

    # Without this guard an empty run divides by zero in the recall
    # computation and takes the mean of an empty list.
    if n_queries == 0:
        return {
            'recall@1': 0.0,
            'recall@5': 0.0,
            'recall@10': 0.0,
            'mrr': 0.0,
            'ndcg@10': 0.0,
            'map': 0.0,
            'latencies': {
                'mean': 0.0,
                'median': 0.0,
            },
            'n_queries': 0,
        }

    ranks = [r['rank'] for r in results]

    def recall_at(cutoff: int) -> float:
        # A falsy rank (None) means the ground truth was not retrieved.
        return sum(1 for rank in ranks if rank and rank <= cutoff) / n_queries

    reciprocal_ranks = [(1.0 / rank) if rank else 0.0 for rank in ranks]
    mrr = float(np.mean(reciprocal_ranks))

    ndcg10 = float(np.mean([ndcg_at_k(rank, k=10) for rank in ranks]))

    # For single relevant item, MAP equals MRR
    map_score = mrr

    latencies = [r['latency'] for r in results]

    return {
        'recall@1': recall_at(1),
        'recall@5': recall_at(5),
        'recall@10': recall_at(10),
        'mrr': mrr,
        'ndcg@10': ndcg10,
        'map': map_score,
        'latencies': {
            'mean': float(np.mean(latencies)),
            'median': float(np.median(latencies)),
        },
        'n_queries': n_queries,
    }

In [None]:
# Load the dataset and models, then wire them into a HybridSearchEngine
bi_encoder, cross_encoder, image_index, dataset = load_components()

print("\n" + "=" * 60, "Initializing HybridSearchEngine (FAST MODE)", "=" * 60, sep="\n")

# FAST MODE engine settings; k1/k2 come from the constants defined earlier.
engine = HybridSearchEngine(
    bi_encoder=bi_encoder,
    cross_encoder=cross_encoder,
    image_index=image_index,
    dataset=dataset,
    config=dict(
        k1=FAST_K1,
        k2=FAST_K2,
        batch_size=8,
        use_cache=False,
        show_progress=False,
        fusion_method='weighted',
        stage1_weight=0.3,
        stage2_weight=0.7,
    ),
)

print("  ✓ Engine initialized")
print(f"  ✓ Images in index: {image_index.index.ntotal:,}")

In [None]:
# Draw the FAST MODE query sample
test_queries = select_test_queries(dataset, n=FAST_N, seed=FAST_SEED)

# Tabulate the first five queries as a quick sanity check
preview_rows = [
    {
        'query': tq['query'],
        'ground_truth': tq['ground_truth'],
        'n_alternatives': len(tq['alternatives']),
    }
    for tq in test_queries[:5]
]

pd.DataFrame(preview_rows)

In [None]:
# CLIP-only evaluation (baseline)
# Each query is answered by the engine's stage-1 retrieval alone; the 1-based
# rank of the ground-truth image (None if absent) and the per-query latency in
# milliseconds are recorded for calculate_metrics.
results_clip: List[Dict[str, Any]] = []

print("\n" + "=" * 70)
print("EVALUATING CLIP-ONLY SEARCH (Baseline)")
print("=" * 70)
print(f"Evaluating {len(test_queries)} queries with k={FAST_K2}")

for tq in test_queries:
    query = tq['query']
    ground_truth = tq['ground_truth']

    # perf_counter is monotonic and high-resolution; time.time() can be too
    # coarse for a retrieval call that may take only a few milliseconds.
    start = time.perf_counter()
    # Use engine's stage-1 CLIP + FAISS retrieval
    search_results = engine._stage1_retrieve(query, k1=FAST_K2)
    latency_ms = (time.perf_counter() - start) * 1000

    retrieved_ids = [img_id for img_id, _score in search_results]

    # Single scan: .index raises ValueError when the ground truth is missing.
    try:
        ground_truth_rank = retrieved_ids.index(ground_truth) + 1
    except ValueError:
        ground_truth_rank = None

    results_clip.append({
        'query': query,
        'ground_truth': ground_truth,
        'retrieved': retrieved_ids,
        'rank': ground_truth_rank,
        'latency': latency_ms,
    })

metrics_clip = calculate_metrics(results_clip, k=FAST_K2)

print("\nCLIP-only metrics (FAST MODE):")
for metric_name, metric_value in metrics_clip.items():
    if metric_name == 'latencies':
        print(f"  {metric_name}: mean={metric_value['mean']:.2f} ms, median={metric_value['median']:.2f} ms")
    else:
        print(f"  {metric_name}: {metric_value}")

In [None]:
# Hybrid evaluation (CLIP + BLIP-2)
# Each query goes through the engine's full hybrid search; the 1-based rank of
# the ground-truth image (None if absent) and the per-query latency in
# milliseconds are recorded for calculate_metrics.
results_hybrid: List[Dict[str, Any]] = []

print("\n" + "=" * 70)
print("EVALUATING HYBRID SEARCH (CLIP + BLIP-2)")
print("=" * 70)
print(f"Evaluating {len(test_queries)} queries with k1={FAST_K1}, k2={FAST_K2}")

for tq in test_queries:
    query = tq['query']
    ground_truth = tq['ground_truth']

    # perf_counter is monotonic, so latencies are unaffected by system clock
    # adjustments during the (comparatively long) hybrid run.
    start = time.perf_counter()
    search_results = engine.text_to_image_hybrid_search(
        query=query,
        k1=FAST_K1,
        k2=FAST_K2,
        show_progress=False,
    )
    latency_ms = (time.perf_counter() - start) * 1000

    retrieved_ids = [img_id for img_id, _score in search_results]

    # Single scan: .index raises ValueError when the ground truth is missing.
    try:
        ground_truth_rank = retrieved_ids.index(ground_truth) + 1
    except ValueError:
        ground_truth_rank = None

    results_hybrid.append({
        'query': query,
        'ground_truth': ground_truth,
        'retrieved': retrieved_ids,
        'rank': ground_truth_rank,
        'latency': latency_ms,
    })

metrics_hybrid = calculate_metrics(results_hybrid, k=FAST_K2)

print("\nHybrid search metrics (FAST MODE):")
for metric_name, metric_value in metrics_hybrid.items():
    if metric_name == 'latencies':
        print(f"  {metric_name}: mean={metric_value['mean']:.2f} ms, median={metric_value['median']:.2f} ms")
    else:
        print(f"  {metric_name}: {metric_value}")

In [None]:
# Build one row per method and line the two methods up side by side
rows = [
    {
        'method': name,
        # Scalar metrics are copied over under their original keys.
        **{key: m[key] for key in ('recall@1', 'recall@5', 'recall@10', 'mrr', 'ndcg@10', 'map')},
        # Latency statistics are flattened out of the nested dict.
        'latency_mean_ms': m['latencies']['mean'],
        'latency_median_ms': m['latencies']['median'],
    }
    for name, m in (('clip_only', metrics_clip), ('hybrid', metrics_hybrid))
]

metrics_df = pd.DataFrame(rows).set_index('method')
display(metrics_df)

# Per-metric difference; positive values mean the hybrid number is larger
delta = metrics_df.loc['hybrid'].sub(metrics_df.loc['clip_only'])
print("\nDeltas (hybrid - clip_only):")
display(delta.to_frame(name='delta'))

In [None]:
# Visualize comparison
sns.set(style="whitegrid")


def _plot_metric_bars(values, title, ylabel, label_fmt, ylim=None):
    """Draw one labelled bar per method for a single metrics_df column.

    Args:
        values: pandas Series indexed by method name.
        title: Figure title.
        ylabel: Y-axis label.
        label_fmt: Format spec for the value printed above each bar
            (e.g. ``'.2f'``).
        ylim: Optional ``(bottom, top)`` for the y-axis. When omitted, 15%
            headroom is added above the tallest bar.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    # Plain matplotlib bars with explicit colours: seaborn 0.13 deprecates
    # passing `palette` to barplot without `hue`, and this form behaves the
    # same on every version.
    ax.bar([str(method) for method in values.index], values.to_numpy(), color=['C0', 'C1'])
    ax.xaxis.grid(False)  # vertical grid lines add nothing on a categorical axis
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('Method')

    # Offset the labels in points rather than data units so the gap looks the
    # same whether the metric is a fraction or a latency in milliseconds.
    for pos, val in enumerate(values):
        ax.annotate(
            format(val, label_fmt),
            xy=(pos, val),
            xytext=(0, 3),
            textcoords='offset points',
            ha='center',
            va='bottom',
        )

    if ylim is not None:
        ax.set_ylim(*ylim)
    else:
        ax.margins(y=0.15)
    plt.show()


# Bar chart for Recall@10
# Recall lies in [0, 1]; the axis runs to 1.1 so that the label of a bar at or
# near 1.0 stays inside the axes instead of colliding with the title.
_plot_metric_bars(
    metrics_df['recall@10'],
    title='Recall@10 Comparison (FAST MODE)',
    ylabel='Recall@10',
    label_fmt='.2f',
    ylim=(0, 1.1),
)

# Bar chart for Mean Latency
_plot_metric_bars(
    metrics_df['latency_mean_ms'],
    title='Mean Latency Comparison (FAST MODE)',
    ylabel='Mean latency (ms)',
    label_fmt='.1f',
)