#### Imports

In [3]:
import hnswlib
import numpy as np
import time
from collections import defaultdict
from sklearn.cluster import KMeans
from collections import Counter


In [7]:
class LabelFilteredANNEvaluator:
    """
    Design Metrics for Filtered ANN Search:
    1. Query Latency:
       - Measures search time with/without filters
       - Compares overhead of filtering
    2. Accuracy Impact:
       - Recall@k: proportion of true nearest neighbors found
       - How filtering affects quality of results
    3. Filter Friction:
       - Filter specificity: proportion of points passing filter
       - Impact of label distribution on performance
    """
    def __init__(self, dim=16, num_elements=3000):
        self.dim = dim
        self.num_elements = num_elements
        self.metrics = defaultdict(list)

    def generate_skewed_labeled_data(self):
        """Generate skewed data with three labels distributed as 60%, 30%, 10%"""
        self.data = np.float32(np.random.random((self.num_elements, self.dim)))
    
        label_a_count = int(self.num_elements * 0.6)  # 60%
        label_b_count = int(self.num_elements * 0.3)  # 30%
        label_c_count = self.num_elements - label_a_count - label_b_count  # Remaining (10%)
    
        self.labels = np.array(['a'] * label_a_count + 
                          ['b'] * label_b_count + 
                          ['c'] * label_c_count)
    
        p = np.random.permutation(len(self.data))
        self.data = self.data[p]
        self.labels = self.labels[p]
    
        unique, counts = np.unique(self.labels, return_counts=True)
        self.metrics['label_distribution'] = dict(zip(unique, counts / len(self.labels)))
    
        return self.data, self.labels
 
    def build_index(self):
        """Build HNSW index with the generated data"""
        self.index = hnswlib.Index(space='cosine', dim=self.dim)
        self.index.init_index(max_elements=self.num_elements, ef_construction=100, M=16)
        self.index.set_ef(20)
        self.index.set_num_threads(1)
        start_time = time.time()
        self.index.add_items(self.data, ids=np.arange(self.num_elements))
        build_time = time.time() - start_time
        self.metrics['build_time'] = build_time

    def create_label_filter(self, target_label):
        """Create filter function for a specific label"""
        def filter_function(idx):
            return self.labels[idx] == target_label
        return filter_function
    
    def calculate_recall(self, filtered_results, true_results, query_points, target_label, k):
        """
    Calculate recall@k for filtered nearest neighbor search results.
    
    Args:
        filtered_results: Results from filtered knn search (n_queries x k)
        true_results: Results from unfiltered knn search (n_queries x k)
        query_points: Query points used for search (n_queries x dim)
        target_label: Label to filter for
        k: Number of nearest neighbors
    
    Returns:
        float: Average recall@k across all queries
        """
        recall = 0
        n_queries = len(query_points)
    
        target_mask = self.labels == target_label
        target_data = self.data[target_mask]
        target_indices = np.where(target_mask)[0]
    
        for i in range(n_queries):
            distances = np.linalg.norm(target_data - query_points[i], axis=1)
            true_neighbor_indices = target_indices[np.argsort(distances)[:k]]
        
            filtered_neighbor_indices = filtered_results[i]
        
            intersection = set(filtered_neighbor_indices) & set(true_neighbor_indices)
            recall += len(intersection) / k
    
        return recall / n_queries

    def generate_distribution_aware_queries(self, num_queries=100):
        """
        Generate random queries with varying filter specificities and attribute distributions
        """
        queries = {}
        
        # High filter specificity (targeting common attributes)
        label_counts = Counter(self.labels)
        common_label = max(label_counts, key=label_counts.get)
        queries['high_specificity'] = np.float32(
            self.data[self.labels == common_label][
                np.random.choice(np.sum(self.labels == common_label), num_queries)
            ]
        )
        
        # Low filter specificity (targeting rare attributes)
        rare_label = min(label_counts, key=label_counts.get)
        queries['low_specificity'] = np.float32(
            self.data[self.labels == rare_label][
                np.random.choice(np.sum(self.labels == rare_label), num_queries)
            ]
        )
        
        # Balanced distribution queries
        balanced_indices = []
        for label in np.unique(self.labels):
            label_indices = np.where(self.labels == label)[0]
            balanced_indices.extend(
                np.random.choice(label_indices, num_queries // len(np.unique(self.labels)))
            )
        queries['balanced'] = np.float32(self.data[balanced_indices])
        
        return queries

    def evaluate_query_performance(self, num_queries=100, k=10):
        """Evaluate query performance with comprehensive metrics including distribution awareness"""
        
        # Generate distribution-aware queries
        query_sets = self.generate_distribution_aware_queries(num_queries)
        all_metrics = {}
        
        for query_type, query_points in query_sets.items():
            metrics = defaultdict(dict)
            
            # Unfiltered baseline
            start_time = time.time()
            unfiltered_labels, unfiltered_distances = self.index.knn_query(query_points, k=k, num_threads=1)
            unfiltered_time = time.time() - start_time
            
            filter_times = {}
            recall_scores = {}
            filter_specificity = {}
            
            # Per each label metrics
            for label in ['a', 'b', 'c']:
                filter_func = self.create_label_filter(label)
                
                # Latency
                start_time = time.time()
                filtered_labels, filtered_distances = self.index.knn_query(
                    query_points, k=k, num_threads=1, filter=filter_func
                )
                filter_time = time.time() - start_time
                filter_times[label] = filter_time / num_queries
                
                # Filter Specificity
                points_passing_filter = sum(filter_func(i) for i in range(self.num_elements))
                filter_specificity[label] = points_passing_filter / self.num_elements
                
                # Accuracy Impact (Recall)
                recall_scores[label] = self.calculate_recall(
                    filtered_labels,
                    unfiltered_labels,
                    query_points,
                    label,
                    k
                )
            
            metrics['query_latency'] = {
                'unfiltered': unfiltered_time / num_queries,
                'filtered': filter_times
            }
            metrics['filter_specificity'] = filter_specificity
            metrics['recall_scores'] = recall_scores
            metrics['filter_friction'] = {
                'latency_overhead': {label: filter_times[label]/metrics['query_latency']['unfiltered'] 
                            for label in filter_times},
                'specificity': filter_specificity,
                'recall_impact': recall_scores
            }
            
            all_metrics[query_type] = metrics
        
        self.metrics['distribution_aware'] = all_metrics
        return all_metrics

In [9]:
def print_evaluation_results(metrics):
    print("\nDistribution-Aware Evaluation Results:")
    
    # Print metrics for each query type
    for query_type, query_metrics in metrics.items():
        print(f"\n=== {query_type} Queries ===")
        
        print(f"\nQuery Latency:")
        print(f"- Unfiltered: {query_metrics['query_latency']['unfiltered']*1000:.2f} ms per query")
        for label, latency in query_metrics['query_latency']['filtered'].items():
            print(f"- Filtered (label {label}): {latency*1000:.2f} ms per query")
        
        print(f"\nFilter Specificity:")
        for label, specificity in query_metrics['filter_specificity'].items():
            print(f"- Label {label}: {specificity*100:.1f}%")
        
        print("\nRecall Scores:")
        for label, recall in query_metrics['recall_scores'].items():
            print(f"- Label {label}: {recall*100:.1f}%")
        
        print("\nFilter Friction:")
        print("Latency Overhead:")
        for label, overhead in query_metrics['filter_friction']['latency_overhead'].items():
            print(f"- Label {label}: {overhead:.2f}x")
        print("Recall Impact:")
        for label, impact in query_metrics['filter_friction']['recall_impact'].items():
            print(f"- Label {label}: {impact*100:.1f}%")

def run_labeled_evaluation():
    evaluator = LabelFilteredANNEvaluator()
    
    print("Generating labeled skewed data...")
    evaluator.generate_skewed_labeled_data()
    
    print("\nLabel Distribution:")
    for label, freq in evaluator.metrics['label_distribution'].items():
        print(f"- Label {label}: {freq*100:.1f}%")
    
    print("\nBuilding index...")
    evaluator.build_index()
    print(f"Build Time: {evaluator.metrics['build_time']:.3f} seconds")
    
    print("\nEvaluating performance...")
    metrics = evaluator.evaluate_query_performance()
    
    print_evaluation_results(metrics)

if __name__ == "__main__":
    run_labeled_evaluation()

Generating labeled skewed data...

Label Distribution:
- Label a: 60.0%
- Label b: 30.0%
- Label c: 10.0%

Building index...
Build Time: 0.748 seconds

Evaluating performance...

Distribution-Aware Evaluation Results:

=== high_specificity Queries ===

Query Latency:
- Unfiltered: 0.13 ms per query
- Filtered (label a): 0.39 ms per query
- Filtered (label b): 0.62 ms per query
- Filtered (label c): 1.85 ms per query

Filter Specificity:
- Label a: 60.0%
- Label b: 30.0%
- Label c: 10.0%

Recall Scores:
- Label a: 74.9%
- Label b: 71.3%
- Label c: 72.8%

Filter Friction:
Latency Overhead:
- Label a: 2.89x
- Label b: 4.61x
- Label c: 13.71x
Recall Impact:
- Label a: 74.9%
- Label b: 71.3%
- Label c: 72.8%

=== low_specificity Queries ===

Query Latency:
- Unfiltered: 0.12 ms per query
- Filtered (label a): 0.42 ms per query
- Filtered (label b): 0.82 ms per query
- Filtered (label c): 1.79 ms per query

Filter Specificity:
- Label a: 60.0%
- Label b: 30.0%
- Label c: 10.0%

Recall Scores

#### Query Types in Skewed Distribution:
Skewed Distribution (60%, 30%, 10%)

High Specificity: 100 queries from most common label (a) (Queries targeting common labels)

Low Specificity: 100 queries from rarest label (c) (Queries targeting rare labels)

Balanced: ~33 queries from each label (a,b,c) (Queries evenly distributed across all labels)

For uniform Distribution (33%, 33%, 33%):

Can use random queries since all labels have equal distribution
No meaningful distinction between high/low specificity
Balanced sampling happens naturally (as in the previous notebook)

Note: All queries are random within their category, but sampling is controlled based on label frequencies in skewed case.






#### Results


Query Types Impact:


high_specificity: queries from label 'a' (most common)
low_specificity: queries from label 'c' (rarest)
balanced: equal mix of all labels


Key Observations:


Latency overhead increases as label rarity increases (c > b > a)
Label 'c' (10%) has highest latency overhead (5-14x slower)
Label 'a' (60%) has lowest latency overhead (2-3x slower)
Recall stays fairly consistent (~70-75%) across all scenarios


Query Workload Implications:


Filtering for rare labels (c) is most expensive
Balanced query workload shows best overall performance
High-specificity queries (from common label 'a') don't necessarily give better recall

The main takeaway is that query performance significantly depends on label distribution - searching for rare labels takes longer but doesn't affect accuracy much.


The pattern shows: as specificity decreases (60% → 30% → 10%), query latency increases. This makes sense because finding k neighbors in a smaller subset of points requires checking more total points to find enough valid ones. However, balanced queries show lower overall latencies, suggesting query source distribution affects performance.

Another pattern: For each label (regardless of its specificity), balanced queries show significantly better latency. The source of the query (high vs low specificity) doesn't impact latency as much as the query distribution type. Specificity still matters - rarer labels (c) take longer across all query types.