In [36]:
import numpy as np
import faiss
import pickle
import time
from collections import defaultdict

In [37]:
# Create Filter Helper Class, it can also be imported from metadata_indexing file
class MetadataFilter:
    """Helper class for efficient metadata filtering"""
    
    def __init__(self, categorical_indexes, numeric_indexes, n_samples):
        self.categorical_indexes = categorical_indexes
        self.numeric_indexes = numeric_indexes
        self.n_samples = n_samples
    
    def apply_filter(self, filter_conditions):
        """
        Apply filter conditions and return a boolean bitmap
        
        filter_conditions format:
        {
            'artist': ['Coldplay', 'Radiohead'],
            'year': (2000, 2020),
            'tempo': (80, 150),
            'mode': [1]
        }
        """
        # Start with all True
        bitmap = np.ones(self.n_samples, dtype=bool)
        
        for key, value in filter_conditions.items():
            if key in self.categorical_indexes:
                # Categorical filter
                valid_indices = set()
                if isinstance(value, list):
                    for val in value:
                        if val in self.categorical_indexes[key]:
                            valid_indices.update(self.categorical_indexes[key][val])
                else:
                    if value in self.categorical_indexes[key]:
                        valid_indices.update(self.categorical_indexes[key][value])
                
                temp_bitmap = np.zeros(self.n_samples, dtype=bool)
                temp_bitmap[list(valid_indices)] = True
                bitmap &= temp_bitmap
            
            elif key in self.numeric_indexes:
                # Numeric range filter
                if isinstance(value, tuple) and len(value) == 2:
                    min_val, max_val = value
                    sorted_indices = np.array(self.numeric_indexes[key]['sorted_indices'])
                    sorted_values = np.array(self.numeric_indexes[key]['sorted_values'])
                    
                    valid_indices = indices_in_range(sorted_indices, sorted_values, min_val, max_val)
                    
                    temp_bitmap = np.zeros(self.n_samples, dtype=bool)
                    temp_bitmap[valid_indices] = True
                    bitmap &= temp_bitmap
        
        return bitmap
    
    def get_valid_indices(self, filter_conditions):
        """Return list of valid indices after filtering"""
        bitmap = self.apply_filter(filter_conditions)
        return np.where(bitmap)[0].tolist()
    
    def get_selectivity(self, filter_conditions):
        """Return selectivity (fraction of data passing filter)"""
        bitmap = self.apply_filter(filter_conditions)
        return bitmap.sum() / self.n_samples

In [38]:
# define indice_in_range, it can also be imported from metadata_indexing file
def indices_in_range(sorted_indices, sorted_values, min_val, max_val):
    """Binary search to find indices within a range"""
    left = np.searchsorted(sorted_values, min_val, side='left')
    right = np.searchsorted(sorted_values, max_val, side='right')
    return sorted_indices[left:right].tolist()

In [39]:
# Load Required Data
# ============================================
print("Loading data...")
embeddings = np.load("data/spotify_vectors_10d.npy").astype('float32')
index_ivf = faiss.read_index("data/index_ivf_flat.faiss")
index_flat = faiss.read_index("data/index_flat_l2.faiss")

with open("data/metadata_filter.pkl", "rb") as f:
    metadata_filter = pickle.load(f)

print(f"Loaded {embeddings.shape[0]} embeddings")
print(f"IVF index with {index_ivf.nlist} clusters")

Loading data...
Loaded 169776 embeddings
IVF index with 412 clusters


In [40]:
# 1. Cluster-Level Metadata Precomputation
# ============================================
print("\n=== Precomputing Cluster Metadata ===")

def compute_cluster_assignments(index_ivf, embeddings):
    """Compute which cluster each vector belongs to"""
    quantizer = faiss.downcast_index(index_ivf.quantizer)
    _, cluster_ids = quantizer.search(embeddings, 1)
    return cluster_ids.flatten()

cluster_assignments = compute_cluster_assignments(index_ivf, embeddings)

# Build cluster -> track indices mapping
cluster_to_tracks = defaultdict(list)
for track_idx, cluster_id in enumerate(cluster_assignments):
    cluster_to_tracks[cluster_id].append(track_idx)

print(f"Computed cluster assignments for {len(cluster_to_tracks)} clusters")


=== Precomputing Cluster Metadata ===
Computed cluster assignments for 412 clusters


In [41]:
# 2. Build Cluster-Level Filter Bitmaps
# ============================================

def build_cluster_filter_bitmaps(cluster_to_tracks, metadata_filter, filter_conditions):
    """
    For each cluster, compute a bitmap indicating whether it contains
    any tracks that satisfy the filter conditions.
    
    Returns:
        cluster_has_valid_tracks: dict mapping cluster_id -> bool
        tracks_per_cluster_passing: dict mapping cluster_id -> list of valid track indices
    """
    cluster_has_valid_tracks = {}
    tracks_per_cluster_passing = {}
    
    # Get global valid tracks bitmap
    global_valid_bitmap = metadata_filter.apply_filter(filter_conditions)
    
    for cluster_id, track_indices in cluster_to_tracks.items():
        # Check which tracks in this cluster pass the filter
        valid_tracks_in_cluster = [idx for idx in track_indices if global_valid_bitmap[idx]]
        
        cluster_has_valid_tracks[cluster_id] = len(valid_tracks_in_cluster) > 0
        tracks_per_cluster_passing[cluster_id] = valid_tracks_in_cluster
    
    return cluster_has_valid_tracks, tracks_per_cluster_passing

In [42]:
# 3. Hybrid Search with Predicate Pushdown
# ============================================

class HybridSearcher:
    """
    Hybrid search that integrates metadata filtering with FAISS traversal
    """
    
    def __init__(self, index_ivf, index_flat, embeddings, cluster_to_tracks, metadata_filter):
        self.index_ivf = index_ivf
        self.index_flat = index_flat
        self.embeddings = embeddings
        self.cluster_to_tracks = cluster_to_tracks
        self.metadata_filter = metadata_filter
        self.n_samples = len(embeddings)
        
        # Statistics
        self.stats = {
            'total_queries': 0,
            'clusters_probed': 0,
            'clusters_skipped': 0,
            'distance_computations': 0,
            'total_latency': 0.0
        }
    
    def search_with_predicate_pushdown(self, query, k, filter_conditions, nprobe=20):
        """
        Perform hybrid search with predicate pushdown
        
        Args:
            query: Query embedding (1, d)
            k: Number of results to return
            filter_conditions: Metadata filter dict
            nprobe: Number of clusters to probe
            
        Returns:
            distances: Top-k distances
            indices: Top-k track indices
            stats: Search statistics
        """
        start_time = time.time()
        
        # Build cluster-level filter bitmaps
        cluster_has_valid, tracks_per_cluster = build_cluster_filter_bitmaps(
            self.cluster_to_tracks, self.metadata_filter, filter_conditions
        )
        
        # Find nearest clusters
        quantizer = faiss.downcast_index(self.index_ivf.quantizer)
        _, nearest_clusters = quantizer.search(query, nprobe * 2)  # Get more candidates
        nearest_clusters = nearest_clusters.flatten()
        
        # Collect candidates with predicate pushdown
        candidates = []
        clusters_probed = 0
        clusters_skipped = 0
        distance_computations = 0
        
        for cluster_id in nearest_clusters:
            cluster_id = int(cluster_id)
            
            # Predicate pushdown: skip clusters without valid tracks
            if cluster_id not in cluster_has_valid or not cluster_has_valid[cluster_id]:
                clusters_skipped += 1
                continue
            
            clusters_probed += 1
            
            # Get valid tracks from this cluster
            valid_tracks = tracks_per_cluster[cluster_id]
            
            # Compute distances only for valid tracks
            for track_idx in valid_tracks:
                dist = np.linalg.norm(query[0] - self.embeddings[track_idx])
                candidates.append((dist, track_idx))
                distance_computations += 1
            
            # Early stopping if we have enough candidates
            if len(candidates) >= k * 3:
                break
        
        # Sort and get top-k
        candidates.sort(key=lambda x: x[0])
        top_k = candidates[:k]
        
        if len(top_k) < k:
            # Pad with dummy results if not enough valid results
            while len(top_k) < k:
                top_k.append((float('inf'), -1))
        
        distances = np.array([d for d, _ in top_k])
        indices = np.array([idx for _, idx in top_k])
        
        latency = time.time() - start_time
        
        # Update statistics
        self.stats['total_queries'] += 1
        self.stats['clusters_probed'] += clusters_probed
        self.stats['clusters_skipped'] += clusters_skipped
        self.stats['distance_computations'] += distance_computations
        self.stats['total_latency'] += latency
        
        search_stats = {
            'latency': latency,
            'clusters_probed': clusters_probed,
            'clusters_skipped': clusters_skipped,
            'distance_computations': distance_computations,
            'candidates_found': len(candidates),
            'nprobe_requested': nprobe
        }
        
        return distances, indices, search_stats
    
    def search_baseline_postfilter(self, query, k, filter_conditions, nprobe=20, candidate_multiplier=10):
        """
        Baseline: Retrieve large candidate pool, then post-filter
        """
        start_time = time.time()
        
        # Retrieve more candidates
        self.index_ivf.nprobe = nprobe
        k_candidates = k * candidate_multiplier
        D, I = self.index_ivf.search(query, k_candidates)
        
        # Apply metadata filter
        valid_bitmap = self.metadata_filter.apply_filter(filter_conditions)
        
        # Filter candidates
        valid_results = []
        for dist, idx in zip(D[0], I[0]):
            if idx >= 0 and valid_bitmap[idx]:
                valid_results.append((dist, idx))
                if len(valid_results) >= k:
                    break
        
        # Pad if needed
        while len(valid_results) < k:
            valid_results.append((float('inf'), -1))
        
        distances = np.array([d for d, _ in valid_results])
        indices = np.array([idx for _, idx in valid_results])
        
        latency = time.time() - start_time
        
        stats = {
            'latency': latency,
            'candidates_retrieved': k_candidates,
            'distance_computations': k_candidates,  # All candidates computed
            'valid_found': len([idx for idx in indices if idx >= 0])
        }
        
        return distances, indices, stats
    
    def search_baseline_prefilter(self, query, k, filter_conditions):
        """
        Baseline: Pre-filter, then exact search on filtered subset
        """
        start_time = time.time()
        
        # Get valid indices
        valid_indices = self.metadata_filter.get_valid_indices(filter_conditions)
        
        if len(valid_indices) == 0:
            # No valid tracks
            return np.full(k, float('inf')), np.full(k, -1), {'latency': time.time() - start_time}
        
        # Extract valid embeddings
        valid_embeddings = self.embeddings[valid_indices]
        
        # Compute distances
        distances_all = np.linalg.norm(valid_embeddings - query[0], axis=1)
        
        # Get top-k
        if len(distances_all) < k:
            top_k_local = np.arange(len(distances_all))
            padding_needed = k - len(distances_all)
            distances = np.concatenate([distances_all, np.full(padding_needed, float('inf'))])
            indices = np.concatenate([np.array(valid_indices), np.full(padding_needed, -1)])
        else:
            top_k_local = np.argpartition(distances_all, k)[:k]
            top_k_local = top_k_local[np.argsort(distances_all[top_k_local])]
            distances = distances_all[top_k_local]
            indices = np.array(valid_indices)[top_k_local]
        
        latency = time.time() - start_time
        
        stats = {
            'latency': latency,
            'filtered_size': len(valid_indices),
            'distance_computations': len(valid_indices)
        }
        
        return distances, indices, stats
    
    def get_ground_truth(self, query, k, filter_conditions):
        """Compute ground truth using brute force on filtered data"""
        valid_indices = self.metadata_filter.get_valid_indices(filter_conditions)
        
        if len(valid_indices) == 0:
            return np.full(k, float('inf')), np.full(k, -1)
        
        valid_embeddings = self.embeddings[valid_indices]
        distances_all = np.linalg.norm(valid_embeddings - query[0], axis=1)
        
        if len(distances_all) < k:
            top_k_local = np.arange(len(distances_all))
            distances = np.concatenate([distances_all, np.full(k - len(distances_all), float('inf'))])
            indices = np.concatenate([np.array(valid_indices), np.full(k - len(distances_all), -1)])
        else:
            top_k_local = np.argpartition(distances_all, k)[:k]
            top_k_local = top_k_local[np.argsort(distances_all[top_k_local])]
            distances = distances_all[top_k_local]
            indices = np.array(valid_indices)[top_k_local]
        
        return distances, indices


In [47]:
# 4. Example Usage and Testing
# ============================================

print("\n=== Initializing Hybrid Searcher ===")
hybrid_searcher = HybridSearcher(
    index_ivf, index_flat, embeddings, 
    cluster_to_tracks, metadata_filter
)

# Test query
test_query = embeddings[5000:5001]  # Use a track from middle of dataset
k = 10

# Example filter - create realistic filter based on available data
test_filter = {}

# Add filters based on what's available
if hasattr(metadata_filter, 'numeric_indexes') and 'year' in metadata_filter.numeric_indexes:
    # Use a realistic year range
    test_filter['year'] = (2010, 2020)

if hasattr(metadata_filter, 'categorical_indexes') and 'mode' in metadata_filter.categorical_indexes:
    test_filter['mode'] = [1]  # Major key

# If no filters available, use a simple one
if not test_filter:
    print("Warning: Using minimal filter for testing")
    test_filter = {'mode': [1]} if 'mode' in metadata_filter.categorical_indexes else {}

print(f"\n=== Testing with filter: {test_filter} ===")
print(f"Filter selectivity: {metadata_filter.get_selectivity(test_filter)*100:.2f}%")

# Run all three methods
print("\n1. Hybrid Search (Predicate Pushdown)")
D_hybrid, I_hybrid, stats_hybrid = hybrid_searcher.search_with_predicate_pushdown(
    test_query, k, test_filter, nprobe=20
)
print(f"   Latency: {stats_hybrid['latency']*1000:.2f}ms")
print(f"   Clusters probed: {stats_hybrid['clusters_probed']}")
print(f"   Clusters skipped: {stats_hybrid['clusters_skipped']}")
print(f"   Distance computations: {stats_hybrid['distance_computations']}")

print("\n2. Post-Filter Baseline")
D_post, I_post, stats_post = hybrid_searcher.search_baseline_postfilter(
    test_query, k, test_filter, nprobe=20
)
print(f"   Latency: {stats_post['latency']*1000:.2f}ms")
print(f"   Distance computations: {stats_post['distance_computations']}")
print(f"   Valid results found: {stats_post['valid_found']}")

print("\n3. Pre-Filter Baseline")
D_pre, I_pre, stats_pre = hybrid_searcher.search_baseline_prefilter(
    test_query, k, test_filter
)
print(f"   Latency: {stats_pre['latency']*1000:.2f}ms")
print(f"   Distance computations: {stats_pre['distance_computations']}")

# Ground truth
print("\n4. Ground Truth (Brute Force)")
start = time.time()
D_true, I_true = hybrid_searcher.get_ground_truth(test_query, k, test_filter)
gt_latency = time.time() - start
print(f"   Latency: {gt_latency*1000:.2f}ms")

# Calculate recalls
def calculate_recall(retrieved, ground_truth, k):
    retrieved_set = set(retrieved[:k][retrieved[:k] >= 0])
    gt_set = set(ground_truth[:k][ground_truth[:k] >= 0])
    if len(gt_set) == 0:
        return 0.0
    return len(retrieved_set & gt_set) / len(gt_set)

recall_hybrid = calculate_recall(I_hybrid, I_true, k)
recall_post = calculate_recall(I_post, I_true, k)
recall_pre = calculate_recall(I_pre, I_true, k)

print(f"\n=== Recall@{k} ===")
print(f"Hybrid (Predicate Pushdown): {recall_hybrid:.3f}")
print(f"Post-Filter Baseline: {recall_post:.3f}")
print(f"Pre-Filter Baseline: {recall_pre:.3f}")

print("\n=== Speedup Analysis ===")
print(f"Hybrid vs Post-Filter: {stats_post['latency']/stats_hybrid['latency']:.2f}x faster")
print(f"Hybrid vs Pre-Filter: {stats_pre['latency']/stats_hybrid['latency']:.2f}x")
print(f"Distance computation reduction: {(1 - stats_hybrid['distance_computations']/stats_post['distance_computations'])*100:.1f}%")

# Save the hybrid searcher
with open("data/hybrid_searcher.pkl", "wb") as f:
    pickle.dump(hybrid_searcher, f)

print("\n=== Hybrid Method Implementation Complete ===")
print("Saved hybrid searcher to data/hybrid_searcher.pkl")


=== Initializing Hybrid Searcher ===

=== Testing with filter: {'year': (2010, 2020), 'mode': [1]} ===
Filter selectivity: 8.21%

1. Hybrid Search (Predicate Pushdown)
   Latency: 53.70ms
   Clusters probed: 7
   Clusters skipped: 12
   Distance computations: 64

2. Post-Filter Baseline
   Latency: 32.41ms
   Distance computations: 100
   Valid results found: 0

3. Pre-Filter Baseline
   Latency: 27.63ms
   Distance computations: 13940

4. Ground Truth (Brute Force)
   Latency: 33.25ms

=== Recall@10 ===
Hybrid (Predicate Pushdown): 0.700
Post-Filter Baseline: 0.000
Pre-Filter Baseline: 1.000

=== Speedup Analysis ===
Hybrid vs Post-Filter: 0.60x faster
Hybrid vs Pre-Filter: 0.51x
Distance computation reduction: 36.0%

=== Hybrid Method Implementation Complete ===
Saved hybrid searcher to data/hybrid_searcher.pkl
