In [16]:
import weaviate
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import re
from typing import Dict, List, Set
import torch
import time

# Initialize SentenceTransformer for manual setup
device = 'cuda' if torch.cuda.is_available() else 'cpu'
clip_model = SentenceTransformer('clip-ViT-B-32', device=device, cache_folder='./model_cache')

# Define queries
queries = ["bag", "dog", "park", "red car", "sunset"]

def is_relevant(caption: str, query: str) -> bool:
    """
    Check if the query term appears in the caption (case-insensitive).
    You can enhance this with synonyms, lemmatization, or visual inspection.
    """
    query_terms = query.lower().split()  # Split query into terms (e.g., "red car" -> ["red", "car"])
    caption_lower = caption.lower()
    return any(re.search(r'\b' + re.escape(term) + r'\b', caption_lower) for term in query_terms)


def generate_ground_truth(client: weaviate.Client, collection_name: str, queries: List[str]) -> Dict[str, Set[str]]:
    """
    Search all captions in the collection to find relevant image_ids for each query.
    Returns a dict mapping query to set of relevant image_ids.
    """
    ground_truth = {query: set() for query in queries}
    collection = client.collections.get(collection_name)
    
    for obj in tqdm(collection.iterator(), desc=f"Generating ground truth for {collection_name}"):
        image_id = obj.properties["image_id"]
        captions = obj.properties["captions"]
        for query in queries:
            if is_relevant(captions, query):
                ground_truth[query].add(image_id)

    for query, relevant_ids in ground_truth.items():
        print(f"Query: {query}, Relevant Images: {len(relevant_ids)}")
    
    return ground_truth


def calculate_metrics(retrieved_ids: List[str], relevant_ids: Set[str], k: int = 5) -> tuple[float, float, float]:
    """
    Calculate Precision@k, Recall@k, and Average Precision for a single query.
    - retrieved_ids: List of retrieved image_ids (top-k results).
    - relevant_ids: Set of ground truth relevant image_ids.
    - k: Number of results to consider (e.g., 5).
    Returns: (precision, recall, average_precision)
    """
    # Truncate to top-k results
    # retrieved_ids = retrieved_ids[:k]
    
    # Count relevant items in retrieved results
    # relevant_retrieved = sum(1 for image_id in retrieved_ids if image_id in relevant_ids)
    relevant = set(relevant_ids)
    retrieved = set(retrieved_ids[:k])
    true_positives = len(relevant & retrieved)

    
    # Precision@k
    precision = true_positives / k if k > 0 else 0.0
    
    # Recall@k
    recall = true_positives / len(relevant_ids) if relevant_ids else 0.0
    
    average_precision = 0.0
    relevant_count = 0
    for rank, image_id in enumerate(retrieved_ids, 1):
        if image_id in relevant_ids:
            relevant_count += 1
            precision_at_rank = relevant_count / rank
            average_precision += precision_at_rank
    average_precision = average_precision / len(relevant_ids) if relevant_ids else 0.0
    
    return precision, recall, average_precision


def evaluate_setup(client: weaviate.Client, collection_name: str, queries: List[str], ground_truth: Dict[str, Set[str]], is_manual: bool = False) -> Dict[str, dict]:
    """
    Run queries and calculate metrics for a given setup.
    - is_manual: True for SentenceTransformer (near_vector), False for multi2vec-clip (near_text).
    Returns: Dict mapping query to metrics (precision, recall, ap, latency, results).
    """
    collection = client.collections.get(collection_name)
    response = collection.aggregate.over_all(total_count=True)
    print(f"{collection_name} collection size is: {response.total_count}")
    results = {}
    
    for query in queries:
        try:
            start_time = time.time()
            if is_manual:
                # Manual setup: Compute query vector and use near_vector
                query_vector = clip_model.encode(query, convert_to_numpy=True).tolist()
                response = collection.query.near_vector(near_vector=query_vector, limit=5)
            else:
                # Multi2vec-clip: Use near_text
                response = collection.query.near_text(query=query, limit=5)
            latency = time.time() - start_time
            
            retrieved_ids = [obj.properties["image_id"] for obj in response.objects]
            retrieved_captions = [obj.properties["captions"] for obj in response.objects]
            
            precision, recall, ap = calculate_metrics(retrieved_ids, ground_truth[query], k=5)
            
            results[query] = {
                "precision": precision,
                "recall": recall,
                "ap": ap,
                "latency": latency,
                "image_ids": retrieved_ids,
                "captions": retrieved_captions
            }
        
        except Exception as e:
            print(f"Error processing query '{query}' in {collection_name}: {str(e)}")
            results[query] = {
                "precision": 0.0,
                "recall": 0.0,
                "ap": 0.0,
                "latency": 0.0,
                "image_ids": [],
                "captions": []
            }
    
    return results


def main():
    client = weaviate.connect_to_local()
    
    # Generate ground truth for both collections
    ground_truth_manual = generate_ground_truth(client, "Flickr30k_manual", queries)
    ground_truth_multi2vec = generate_ground_truth(client, "Flickr30k_multi2vec", queries)
    
    # Evaluate both setups
    results_manual = evaluate_setup(client, "Flickr30k_manual", queries, ground_truth_manual, is_manual=True)
    results_multi2vec = evaluate_setup(client, "Flickr30k_multi2vec", queries, ground_truth_multi2vec, is_manual=False)
    
    # Summarize results
    print("\nMulti2vec Results:")
    avg_precision_multi2vec = avg_recall_multi2vec = avg_ap_multi2vec = avg_latency_multi2vec = 0.0
    for query, metrics in results_multi2vec.items():
        print(f"Query: {query}, Precision@5: {metrics['precision']:.3f}, Recall@5: {metrics['recall']:.3f}, AP@5: {metrics['ap']:.3f}")
        avg_precision_multi2vec += metrics['precision']
        avg_recall_multi2vec += metrics['recall']
        avg_ap_multi2vec += metrics['ap']
        avg_latency_multi2vec += metrics['latency']

    # avg_precision_multi2vec /= len(queries)
    # avg_recall_multi2vec /= len(queries)
    # avg_ap_multi2vec /= len(queries)
    # avg_latency_multi2vec /= len(queries)
    
    # print(f"\nAverage Metrics (Multi2vec-clip):")
    # print(f"  Avg Precision@5: {avg_precision_multi2vec:.3f}")
    # print(f"  Avg Recall@5: {avg_recall_multi2vec:.3f}")
    # print(f"  mAP: {avg_ap_multi2vec:.3f}")
    # print(f"  Avg Latency: {avg_latency_multi2vec:.3f}s")

    print("\nManual Results:")
    avg_precision_manual = avg_recall_manual = avg_ap_manual = avg_latency_manual = 0.0
    for query, metrics in results_manual.items():
        print(f"Query: {query}, Precision@5: {metrics['precision']:.3f}, Recall@5: {metrics['recall']:.3f}, AP@5: {metrics['ap']:.3f}")
        avg_precision_manual += metrics['precision']
        avg_recall_manual += metrics['recall']
        avg_ap_manual += metrics['ap']
        avg_latency_manual += metrics['latency']

    # avg_precision_manual /= len(queries)
    # avg_recall_manual /= len(queries)
    # avg_ap_manual /= len(queries)
    # avg_latency_manual /= len(queries)
    
    # print(f"\nAverage Metrics (Manual):")
    # print(f"  Avg Precision@5: {avg_precision_manual:.3f}")
    # print(f"  Avg Recall@5: {avg_recall_manual:.3f}")
    # print(f"  mAP: {avg_ap_manual:.3f}")
    # print(f"  Avg Latency: {avg_latency_manual:.3f}s")

    
    client.close()


In [17]:
if __name__ == "__main__":
    main()

Generating ground truth for Flickr30k_manual: 598it [00:00, 1336.04it/s]


Query: bag, Relevant Images: 8
Query: dog, Relevant Images: 36
Query: park, Relevant Images: 12
Query: red car, Relevant Images: 99
Query: sunset, Relevant Images: 2


Generating ground truth for Flickr30k_multi2vec: 598it [00:00, 2466.99it/s]


Query: bag, Relevant Images: 8
Query: dog, Relevant Images: 36
Query: park, Relevant Images: 12
Query: red car, Relevant Images: 99
Query: sunset, Relevant Images: 2
Flickr30k_manual collection size is: 598
Flickr30k_multi2vec collection size is: 598

Multi2vec Results:
Query: bag, Precision@5: 0.200, Recall@5: 0.125, AP@5: 0.031
Query: dog, Precision@5: 1.000, Recall@5: 0.139, AP@5: 0.139
Query: park, Precision@5: 0.000, Recall@5: 0.000, AP@5: 0.000
Query: red car, Precision@5: 0.200, Recall@5: 0.010, AP@5: 0.002
Query: sunset, Precision@5: 0.000, Recall@5: 0.000, AP@5: 0.000

Manual Results:
Query: bag, Precision@5: 0.000, Recall@5: 0.000, AP@5: 0.000
Query: dog, Precision@5: 0.800, Recall@5: 0.111, AP@5: 0.106
Query: park, Precision@5: 0.000, Recall@5: 0.000, AP@5: 0.000
Query: red car, Precision@5: 0.400, Recall@5: 0.020, AP@5: 0.014
Query: sunset, Precision@5: 0.000, Recall@5: 0.000, AP@5: 0.000
