# ML LeetCode - Part 3: Data Structures and Efficiency 🚀

This notebook focuses on efficient data structures and algorithms crucial for machine learning at scale. Each problem emphasizes computational efficiency, memory optimization, and scalable solutions.

## 🎯 Learning Objectives
- Master efficient nearest neighbor search algorithms
- Implement streaming algorithms for large datasets
- Optimize memory usage with smart data structures
- Handle high-dimensional data efficiently

## 📊 Difficulty Levels
- 🟢 **Easy**: Basic data structure implementations
- 🟡 **Medium**: Optimized algorithms with trade-offs
- 🔴 **Hard**: Advanced algorithms for production systems

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Tuple, Optional, Dict, Any, Set
import time
import math
from collections import Counter, defaultdict, deque
import heapq
import hashlib
import random

def benchmark_algorithm(func, *args, name="Algorithm", n_runs=5):
    """Benchmark an algorithm's performance."""
    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        result = func(*args)
        end = time.perf_counter()
        times.append(end - start)
    
    avg_time = np.mean(times)
    print(f"⚡ {name}: {avg_time*1000:.3f}ms (±{np.std(times)*1000:.3f}ms)")
    return result, avg_time

plt.style.use('seaborn-v0_8')
np.random.seed(42)

## Problem 1: KD-Tree for Efficient Nearest Neighbor Search 🔴

**Difficulty**: Hard

**Problem**: Implement a KD-tree data structure for efficient nearest neighbor search in multi-dimensional space.

**Constraints**:
- 1 ≤ n_points ≤ 100,000
- 1 ≤ dimensions ≤ 20
- Support k-nearest neighbors and range queries
- Handle dynamic insertions and deletions

**Example**:
```python
points = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
kdtree = KDTree(points)
nearest = kdtree.query([9, 2], k=2)
# Expected: closest points to [9, 2]
```
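
The expected result for this example can be checked by brute force before any tree is built. The sketch below is a minimal NumPy check; `brute_force_nearest` is a hypothetical helper name used only for illustration and is not part of the `KDTree` class.

```python
import numpy as np

def brute_force_nearest(points, query, k):
    """Return the k points closest to `query` as (distance, point) pairs."""
    pts = np.asarray(points, dtype=float)
    # Euclidean distance from the query to every point
    dists = np.linalg.norm(pts - np.asarray(query, dtype=float), axis=1)
    # Indices of the k smallest distances, closest first
    order = np.argsort(dists, kind="stable")[:k]
    return [(float(dists[i]), pts[i].tolist()) for i in order]

points = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
nearest = brute_force_nearest(points, [9, 2], k=2)
print(nearest)
# The two closest points are [8, 1] (distance sqrt(2)) and [7, 2] (distance 2).
```

A KD-tree query on the same data should return the same two points in the same order.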

In [None]:
class KDNode:
    def __init__(self, point, axis, left=None, right=None):
        self.point = point
        self.axis = axis
        self.left = left
        self.right = right

class KDTree:
    
    def __init__(self, points: List[List[float]] = None):
        self.root = None
        self.dimensions = 0
        if points:
            self.build(points)
    
    def build(self, points: List[List[float]]):
        """
        Build KD-tree from list of points.
        
        Time Complexity: O(n log^2 n) (each recursion level re-sorts its points)
        Space Complexity: O(n)
        """
        if not points:
            return
        
        points = [np.array(p) for p in points]
        self.dimensions = len(points[0])
        self.root = self._build_recursive(points, depth=0)
    
    def _build_recursive(self, points: List[np.ndarray], depth: int) -> Optional[KDNode]:
        """Recursively build KD-tree."""
        if not points:
            return None
        
        # Select axis based on depth
        axis = depth % self.dimensions
        
        # Sort points by current axis
        points.sort(key=lambda x: x[axis])
        
        # Choose median as pivot
        median = len(points) // 2
        
        # Create node and construct subtrees
        node = KDNode(
            point=points[median],
            axis=axis,
            left=self._build_recursive(points[:median], depth + 1),
            right=self._build_recursive(points[median + 1:], depth + 1)
        )
        
        return node
    
    def query(self, query_point: List[float], k: int = 1) -> List[Tuple[float, List[float]]]:
        """
        Find k nearest neighbors.
        
        Time Complexity: O(log n) average for small k in low dimensions, O(n) worst case
        Space Complexity: O(log n + k) for a balanced tree
        
        Returns: List of (distance, point) tuples
        """
        query_point = np.array(query_point)
        best_neighbors = []  # Max heap via negated distances: (-distance, point)
        
        def add_neighbor(point, distance):
            if len(best_neighbors) < k:
                heapq.heappush(best_neighbors, (-distance, point.tolist()))
            elif distance < -best_neighbors[0][0]:
                heapq.heappop(best_neighbors)
                heapq.heappush(best_neighbors, (-distance, point.tolist()))
        
        def search_recursive(node):
            if node is None:
                return
            
            # Calculate distance to current node
            distance = np.linalg.norm(query_point - node.point)
            add_neighbor(node.point, distance)
            
            # Determine which side to search first
            axis = node.axis
            diff = query_point[axis] - node.point[axis]
            
            # Search near side first
            near_side = node.left if diff <= 0 else node.right
            far_side = node.right if diff <= 0 else node.left
            
            search_recursive(near_side)
            
            # Check if we need to search far side
            if (len(best_neighbors) < k or 
                abs(diff) < -best_neighbors[0][0]):
                search_recursive(far_side)
        
        search_recursive(self.root)
        
        # Convert max heap to sorted list (closest first)
        result = [(-dist, point) for dist, point in best_neighbors]
        result.sort()
        
        return result
    
    def range_query(self, center: List[float], radius: float) -> List[List[float]]:
        """
        Find all points within radius of center.
        
        Time Complexity: O(n) worst case
        Space Complexity: O(n)
        """
        center = np.array(center)
        result = []
        
        def search_recursive(node):
            if node is None:
                return
            
            # Check if current point is in range
            distance = np.linalg.norm(center - node.point)
            if distance <= radius:
                result.append(node.point.tolist())
            
            # Check if we need to search child nodes
            axis = node.axis
            diff = center[axis] - node.point[axis]
            
            if diff <= radius:
                search_recursive(node.left)
            if diff >= -radius:
                search_recursive(node.right)
        
        search_recursive(self.root)
        return result
    
    def insert(self, point: List[float]):
        """
        Insert a new point into the tree.
        Note: This is a simplified insertion that may lead to unbalanced trees.
        """
        point = np.array(point)
        if self.root is None:
            self.dimensions = len(point)
            self.root = KDNode(point, 0)
        else:
            self._insert_recursive(self.root, point, 0)
    
    def _insert_recursive(self, node: KDNode, point: np.ndarray, depth: int):
        """Recursively insert point."""
        axis = depth % self.dimensions
        
        if point[axis] <= node.point[axis]:
            if node.left is None:
                node.left = KDNode(point, (depth + 1) % self.dimensions)
            else:
                self._insert_recursive(node.left, point, depth + 1)
        else:
            if node.right is None:
                node.right = KDNode(point, (depth + 1) % self.dimensions)
            else:
                self._insert_recursive(node.right, point, depth + 1)

# Test KD-Tree implementation
print("=== Problem 1: KD-Tree Implementation ===")

# Test case 1: Small 2D dataset
points_2d = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
kdtree = KDTree(points_2d)

query_point = [9, 2]
nearest = kdtree.query(query_point, k=3)

print(f"\nTest Case 1: 2D Nearest Neighbors")
print(f"Query point: {query_point}")
print(f"3 nearest neighbors:")
for i, (dist, point) in enumerate(nearest, 1):
    print(f"  {i}. {point} (distance: {dist:.3f})")

# Test range query
center = [5, 4]
radius = 3.0
points_in_range = kdtree.range_query(center, radius)

print(f"\nRange Query:")
print(f"Center: {center}, Radius: {radius}")
print(f"Points in range: {points_in_range}")

# Benchmark against brute force
def brute_force_knn(points, query, k):
    """Brute force k-nearest neighbors."""
    distances = [(np.linalg.norm(np.array(p) - np.array(query)), p) for p in points]
    distances.sort()
    return distances[:k]

# Generate larger dataset for benchmarking
np.random.seed(42)
large_points = np.random.randn(1000, 3).tolist()
large_kdtree = KDTree(large_points)
test_query = [0.5, 0.5, 0.5]

print(f"\nPerformance Comparison (1000 points, 3D):")
kd_result, kd_time = benchmark_algorithm(large_kdtree.query, test_query, 5, name="KD-Tree")
bf_result, bf_time = benchmark_algorithm(brute_force_knn, large_points, test_query, 5, name="Brute Force")

print(f"Speedup: {bf_time/kd_time:.1f}x")

# Verify results match
kd_distances = [dist for dist, _ in kd_result]
bf_distances = [dist for dist, _ in bf_result]
matches = np.allclose(kd_distances, bf_distances, rtol=1e-10)
print(f"Results match: {matches}")

## Problem 2: Locality Sensitive Hashing (LSH) 🟡

**Difficulty**: Medium

**Problem**: Implement LSH for approximate nearest neighbor search in high-dimensional spaces.

**Constraints**:
- 1 ≤ n_points ≤ 50,000
- 10 ≤ dimensions ≤ 1000
- Support different hash families (random projections, p-stable)
- Configurable recall vs speed trade-off

**Example**:
```python
points = [[1, 0, 1, 0], [0, 1, 0, 1], [1, 1, 0, 0]]
lsh = LSH(n_hashes=10, n_bands=5)
lsh.fit(points)
candidates = lsh.query([1, 0, 0, 1], max_candidates=10)
```
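
The recall versus speed trade-off can be estimated before building an index. Under the standard banding analysis, if a single hash agrees for two points with probability `s`, the points share at least one band bucket with probability `1 - (1 - s^r)^b` for `b` bands of `r` rows each. For random hyperplane hashes, `s = 1 - θ/π`, where `θ` is the angle between the two vectors. The function names below are illustrative assumptions, not part of the `LSH` class.

```python
import math

def hyperplane_agreement(cos_sim: float) -> float:
    """Probability that one random-hyperplane bit agrees for two vectors."""
    theta = math.acos(max(-1.0, min(1.0, cos_sim)))  # clamp for safety
    return 1.0 - theta / math.pi

def candidate_probability(s: float, n_hashes: int, n_bands: int) -> float:
    """Probability that two points collide in at least one band."""
    rows = n_hashes // n_bands  # rows per band
    return 1.0 - (1.0 - s ** rows) ** n_bands

for cos_sim in (0.5, 0.8, 0.95):
    s = hyperplane_agreement(cos_sim)
    few = candidate_probability(s, n_hashes=20, n_bands=4)
    many = candidate_probability(s, n_hashes=20, n_bands=10)
    print(f"cos={cos_sim:.2f}  s={s:.3f}  4 bands: {few:.3f}  10 bands: {many:.3f}")
```

With the number of hashes fixed, more bands means fewer rows per band, so both near and far points become candidates more often: recall rises, and so does the number of distance computations per query.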

In [None]:
class LSH:
    
    def __init__(self, n_hashes: int = 10, n_bands: int = 5, hash_family: str = 'random_projection'):
        self.n_hashes = n_hashes
        self.n_bands = n_bands
        self.rows_per_band = n_hashes // n_bands
        self.hash_family = hash_family
        self.hash_functions = []
        self.hash_tables = [defaultdict(list) for _ in range(n_bands)]
        self.points = []
        self.dimensions = 0
    
    def _generate_hash_functions(self, dimensions: int):
        """Generate hash functions based on chosen family."""
        self.hash_functions = []
        
        if self.hash_family == 'random_projection':
            # Random hyperplanes for cosine similarity
            for _ in range(self.n_hashes):
                random_vector = np.random.randn(dimensions)
                self.hash_functions.append(random_vector)
        
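        elif self.hash_family == 'p_stable':
            # p-stable hash for Euclidean distance (p=2)
            self.hash_width = 1.0  # Hash bucket width
            for _ in range(self.n_hashes):
                a = np.random.randn(dimensions)
                b = np.random.uniform(0, self.hash_width)
                self.hash_functions.append((a, b))
        
        else:
            # Fail loudly instead of silently building an index with no hash functions
            raise ValueError(f"Unknown hash_family: {self.hash_family!r}")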
    
    def _hash_point(self, point: np.ndarray) -> List[int]:
        """Compute hash values for a point."""
        hashes = []
        
        if self.hash_family == 'random_projection':
            for random_vector in self.hash_functions:
                # Hash is 1 if dot product > 0, else 0
                hash_val = 1 if np.dot(point, random_vector) > 0 else 0
                hashes.append(hash_val)
        
        elif self.hash_family == 'p_stable':
            for a, b in self.hash_functions:
                # Hash = floor((a·x + b) / w); int() would truncate toward zero
                # and merge the two buckets on either side of 0
                hash_val = math.floor((np.dot(a, point) + b) / self.hash_width)
                hashes.append(hash_val)
        
        return hashes
    
    def _create_band_hashes(self, hashes: List[int]) -> List[str]:
        """Create band hashes from individual hash values."""
        band_hashes = []
        
        for band in range(self.n_bands):
            start_idx = band * self.rows_per_band
            end_idx = min(start_idx + self.rows_per_band, len(hashes))
            
            # Create composite hash for this band
            band_signature = tuple(hashes[start_idx:end_idx])
            band_hash = str(hash(band_signature))
            band_hashes.append(band_hash)
        
        return band_hashes
    
    def fit(self, points: List[List[float]]):
        """
        Build LSH index.
        
        Time Complexity: O(n * d * h) where h is number of hashes
        Space Complexity: O(n * h)
        """
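        self.points = [np.array(p) for p in points]
        self.dimensions = len(self.points[0])
        
        # Reset tables so a repeated fit() does not keep stale entries
        self.hash_tables = [defaultdict(list) for _ in range(self.n_bands)]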
        
        # Generate hash functions
        self._generate_hash_functions(self.dimensions)
        
        # Hash all points and add to tables
        for point_idx, point in enumerate(self.points):
            hashes = self._hash_point(point)
            band_hashes = self._create_band_hashes(hashes)
            
            # Add to each band's hash table
            for band_idx, band_hash in enumerate(band_hashes):
                self.hash_tables[band_idx][band_hash].append(point_idx)
    
    def query(self, query_point: List[float], max_candidates: int = 100) -> List[int]:
        """
        Find candidate nearest neighbors.
        
        Time Complexity: O(d * h + c) where c is number of candidates
        Space Complexity: O(c)
        
        Returns: List of point indices that are candidates
        """
        query_point = np.array(query_point)
        candidates = set()
        
        # Hash query point
        hashes = self._hash_point(query_point)
        band_hashes = self._create_band_hashes(hashes)
        
        # Find candidates from each band
        for band_idx, band_hash in enumerate(band_hashes):
            if band_hash in self.hash_tables[band_idx]:
                candidates.update(self.hash_tables[band_idx][band_hash])
            
            if len(candidates) >= max_candidates:
                break
        
        return list(candidates)[:max_candidates]
    
    def query_with_distances(self, query_point: List[float], k: int = 5, 
                           max_candidates: int = 100) -> List[Tuple[float, int]]:
        """
        Find k nearest neighbors from LSH candidates.
        """
        query_point = np.array(query_point)
        candidates = self.query(query_point, max_candidates)
        
        # Calculate actual distances to candidates
        distances = []
        for candidate_idx in candidates:
            candidate_point = self.points[candidate_idx]
            distance = np.linalg.norm(query_point - candidate_point)
            distances.append((distance, candidate_idx))
        
        # Return k closest
        distances.sort()
        return distances[:k]
    
    def get_stats(self) -> Dict[str, Any]:
        """Get statistics about the LSH index."""
        total_buckets = sum(len(table) for table in self.hash_tables)
        max_bucket_size = max(max(len(bucket) for bucket in table.values()) 
                             if table else 0 for table in self.hash_tables)
        avg_bucket_size = np.mean([len(bucket) for table in self.hash_tables 
                                  for bucket in table.values()]) if total_buckets > 0 else 0
        
        return {
            'total_buckets': total_buckets,
            'max_bucket_size': max_bucket_size,
            'avg_bucket_size': avg_bucket_size,
            'n_points': len(self.points),
            'dimensions': self.dimensions
        }

# Test LSH implementation
print("\n=== Problem 2: Locality Sensitive Hashing ===")

# Generate high-dimensional test data
np.random.seed(42)
n_points, n_dims = 1000, 50
points_hd = np.random.randn(n_points, n_dims)

# Add some clustered structure
cluster_centers = np.random.randn(5, n_dims) * 3
for i in range(n_points):
    cluster_id = i % 5
    points_hd[i] += cluster_centers[cluster_id]

points_hd_list = points_hd.tolist()

print(f"\nDataset: {n_points} points in {n_dims} dimensions")

# Test different LSH configurations
configs = [
    {'n_hashes': 20, 'n_bands': 4, 'hash_family': 'random_projection'},
    {'n_hashes': 20, 'n_bands': 10, 'hash_family': 'random_projection'},
    {'n_hashes': 20, 'n_bands': 4, 'hash_family': 'p_stable'}
]

results = {}
query_point = points_hd[0]  # Use first point as query
true_knn = 5

# Brute force for ground truth
def brute_force_knn_hd(points, query, k):
    distances = [(np.linalg.norm(p - query), i) for i, p in enumerate(points)]
    distances.sort()
    return distances[:k]

true_neighbors = brute_force_knn_hd(points_hd, query_point, true_knn)
true_indices = {idx for _, idx in true_neighbors}

print(f"\nTrue {true_knn}-NN indices: {sorted(true_indices)}")

for i, config in enumerate(configs):
    print(f"\nConfiguration {i+1}: {config}")
    
    # Build LSH index
    lsh = LSH(**config)
    build_time = time.perf_counter()
    lsh.fit(points_hd_list)
    build_time = time.perf_counter() - build_time
    
    # Query
    query_time = time.perf_counter()
    lsh_neighbors = lsh.query_with_distances(query_point.tolist(), k=true_knn, max_candidates=200)
    query_time = time.perf_counter() - query_time
    
    lsh_indices = {idx for _, idx in lsh_neighbors}
    
    # Calculate recall
    recall = len(true_indices.intersection(lsh_indices)) / len(true_indices)
    
    # Get LSH stats
    stats = lsh.get_stats()
    
    results[i] = {
        'config': config,
        'recall': recall,
        'build_time': build_time,
        'query_time': query_time,
        'stats': stats
    }
    
    print(f"  Recall: {recall:.3f}")
    print(f"  Build time: {build_time*1000:.2f}ms")
    print(f"  Query time: {query_time*1000:.4f}ms")
    print(f"  Avg bucket size: {stats['avg_bucket_size']:.1f}")
    print(f"  LSH neighbors: {sorted(lsh_indices)}")

# Visualize recall vs speed trade-off
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Recall vs Query Time
recalls = [results[i]['recall'] for i in range(len(configs))]
query_times = [results[i]['query_time'] * 1000 for i in range(len(configs))]
config_labels = [f"Config {i+1}" for i in range(len(configs))]

ax1.scatter(query_times, recalls, s=100, alpha=0.7)
for i, label in enumerate(config_labels):
    ax1.annotate(label, (query_times[i], recalls[i]), 
                xytext=(5, 5), textcoords='offset points')

ax1.set_xlabel('Query Time (ms)')
ax1.set_ylabel('Recall')
ax1.set_title('LSH Recall vs Speed Trade-off')
ax1.grid(True, alpha=0.3)

# Bucket size comparison
bucket_sizes = [results[i]['stats']['avg_bucket_size'] for i in range(len(configs))]
ax2.bar(config_labels, bucket_sizes, alpha=0.7)
ax2.set_ylabel('Average Bucket Size')
ax2.set_title('Hash Table Bucket Sizes')
ax2.tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()

print("\n📊 LSH Analysis:")
print("• More bands (fewer rows per band) → higher recall but more candidates to check")
print("• Random projection good for cosine similarity")
print("• p-stable hashing good for Euclidean distance")
print("• Trade-off between recall and speed is configurable")

## Problem 3: Count-Min Sketch for Streaming Data 🟡

**Difficulty**: Medium

**Problem**: Implement Count-Min Sketch for approximate frequency counting in data streams.

**Constraints**:
- Memory usage: O(ε⁻¹ log δ⁻¹)
- Error guarantee: |count(x) - count*(x)| ≤ ε·N with probability at least 1-δ, where N is the total count of the stream
- Support for point queries and heavy hitter detection
- Handle arbitrary data types as items

**Example**:
```python
cms = CountMinSketch(epsilon=0.01, delta=0.01)
for item in stream:
    cms.add(item)
estimated_count = cms.query("frequent_item")
heavy_hitters = cms.get_heavy_hitters(threshold=100)
```
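
The memory bound translates into concrete table sizes. A minimal sketch, assuming the usual sizing rule of width `⌈e/ε⌉` and depth `⌈ln(1/δ)⌉` with 8-byte counters; `sketch_dimensions` is a hypothetical helper name used only for illustration.

```python
import math

def sketch_dimensions(epsilon: float, delta: float) -> tuple:
    """Return (width, depth) of a Count-Min Sketch for the given error parameters."""
    width = math.ceil(math.e / epsilon)       # buckets per hash row
    depth = math.ceil(math.log(1.0 / delta))  # number of hash rows
    return width, depth

for eps, delta in [(0.01, 0.01), (0.05, 0.05), (0.1, 0.1)]:
    width, depth = sketch_dimensions(eps, delta)
    n_bytes = width * depth * 8  # assumes 8-byte (int64) counters
    print(f"eps={eps}, delta={delta}: {depth} x {width} counters, {n_bytes} bytes")
```

The size depends only on ε and δ, not on the number of distinct items in the stream, which is what makes the structure suitable for unbounded streams.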

In [None]:
# MurmurHash3 gives better hash distribution; fall back to the built-in hash if it is missing
try:
    import mmh3
except ImportError:
    mmh3 = None

class CountMinSketch:
    
    def __init__(self, epsilon: float = 0.01, delta: float = 0.01):
        """
        Initialize Count-Min Sketch.
        
        Args:
            epsilon: Additive error bound as a fraction of the total count N (smaller = more accurate)
            delta: Failure probability (smaller = more reliable)
        """
        self.epsilon = epsilon
        self.delta = delta
        
        # Calculate dimensions based on error parameters
        self.width = int(math.ceil(math.e / epsilon))  # Number of buckets per hash
        self.height = int(math.ceil(math.log(1 / delta)))  # Number of hash functions
        
        # Initialize count matrix
        self.counts = np.zeros((self.height, self.width), dtype=np.int64)
        
        # Generate hash function seeds from a local RNG: fixed seed for
        # reproducibility, without resetting the global random state
        rng = random.Random(42)
        self.hash_seeds = [rng.randint(0, 2**31 - 1) for _ in range(self.height)]
        
        # Track total number of items
        self.total_count = 0
        
        print(f"Count-Min Sketch: {self.height}x{self.width} matrix")
        print(f"Memory usage: {self.height * self.width * 8} bytes")
    
    def _hash(self, item: Any, seed: int) -> int:
        """Hash function for arbitrary items."""
        # Convert item to string for hashing
        item_str = str(item).encode('utf-8')
        
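        # Use MurmurHash3 for good distribution when the module is available
        if mmh3 is not None:
            hash_val = mmh3.hash(item_str, seed, signed=False)
        else:
            # Fallback to built-in hash; it is salted per process, so values
            # are only consistent within a single Python session
            hash_val = hash((item_str, seed)) % (2**32)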
        
        return hash_val % self.width
    
    def add(self, item: Any, count: int = 1):
        """
        Add item to the sketch.
        
        Time Complexity: O(height)
        Space Complexity: O(1)
        """
        for i in range(self.height):
            j = self._hash(item, self.hash_seeds[i])
            self.counts[i, j] += count
        
        self.total_count += count
    
    def query(self, item: Any) -> int:
        """
        Estimate count of item.
        
        Time Complexity: O(height)
        Space Complexity: O(1)
        
        Returns: Estimated count (upper bound)
        """
        min_count = float('inf')
        
        for i in range(self.height):
            j = self._hash(item, self.hash_seeds[i])
            min_count = min(min_count, self.counts[i, j])
        
        return int(min_count)
    
    def get_heavy_hitters(self, threshold: int) -> List[Tuple[str, int]]:
        """
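        Find buckets whose count reaches the threshold.
        Note: The sketch stores no keys, so this simplified version returns
        (bucket label, count) pairs. Recovering the actual items requires
        tracking candidate keys separately and checking them with query().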
        """
        candidates = set()
        
        # Find buckets with counts above threshold
        for i in range(self.height):
            for j in range(self.width):
                if self.counts[i, j] >= threshold:
                    # This bucket might contain heavy hitters
                    candidates.add((i, j, self.counts[i, j]))
        
        # Return unique candidate buckets
        heavy_hitters = [(f"bucket_{i}_{j}", count) for i, j, count in candidates]
        heavy_hitters.sort(key=lambda x: x[1], reverse=True)
        
        return heavy_hitters
    
    def merge(self, other: 'CountMinSketch') -> 'CountMinSketch':
        """
        Merge two Count-Min Sketches.
        Both sketches must have same dimensions.
        """
        if (self.height != other.height or 
            self.width != other.width or 
            self.hash_seeds != other.hash_seeds):
            raise ValueError("Sketches must have same parameters for merging")
        
        merged = CountMinSketch(self.epsilon, self.delta)
        merged.height = self.height
        merged.width = self.width
        merged.hash_seeds = self.hash_seeds
        merged.counts = self.counts + other.counts
        merged.total_count = self.total_count + other.total_count
        
        return merged
    
    def get_stats(self) -> Dict[str, Any]:
        """Get statistics about the sketch."""
        return {
            'dimensions': (self.height, self.width),
            'total_count': self.total_count,
            'memory_bytes': self.height * self.width * 8,
            'epsilon': self.epsilon,
            'delta': self.delta,
            'max_bucket_count': int(np.max(self.counts)),
            'avg_bucket_count': float(np.mean(self.counts))
        }

# Test Count-Min Sketch
print("\n=== Problem 3: Count-Min Sketch ===")

# Simulate a data stream with Zipfian distribution
def generate_zipfian_stream(n_items: int, n_unique: int, alpha: float = 1.0) -> List[str]:
    """Generate stream with Zipfian frequency distribution."""
    # Generate frequencies according to Zipf's law
    frequencies = np.array([1/i**alpha for i in range(1, n_unique + 1)])
    frequencies = frequencies / frequencies.sum()
    
    # Generate items
    items = [f"item_{i}" for i in range(n_unique)]
    stream = np.random.choice(items, size=n_items, p=frequencies)
    
    return stream.tolist()

# Generate test stream
np.random.seed(42)
stream_size = 10000
unique_items = 1000
stream = generate_zipfian_stream(stream_size, unique_items, alpha=1.2)

print(f"\nStream: {stream_size} items, {unique_items} unique")

# True counts for comparison
true_counts = Counter(stream)
most_frequent = true_counts.most_common(10)

print("\nTrue top 10 items:")
for item, count in most_frequent:
    print(f"  {item}: {count}")

# Test different sketch configurations
configs = [
    {'epsilon': 0.01, 'delta': 0.01},  # High accuracy
    {'epsilon': 0.05, 'delta': 0.05},  # Medium accuracy
    {'epsilon': 0.1, 'delta': 0.1}    # Lower accuracy, less memory
]

sketch_results = []

for i, config in enumerate(configs):
    print(f"\nConfiguration {i+1}: ε={config['epsilon']}, δ={config['delta']}")
    
    # Build sketch
    cms = CountMinSketch(**config)
    
    # Process stream
    start_time = time.perf_counter()
    for item in stream:
        cms.add(item)
    processing_time = time.perf_counter() - start_time
    
    # Test queries on most frequent items
    errors = []
    for item, true_count in most_frequent:
        estimated_count = cms.query(item)
        error = abs(estimated_count - true_count) / true_count
        errors.append(error)
    
    avg_error = np.mean(errors)
    max_error = np.max(errors)
    
    stats = cms.get_stats()
    
    sketch_results.append({
        'config': config,
        'avg_error': avg_error,
        'max_error': max_error,
        'processing_time': processing_time,
        'memory_bytes': stats['memory_bytes']
    })
    
    print(f"  Processing time: {processing_time*1000:.2f}ms")
    print(f"  Memory usage: {stats['memory_bytes']} bytes")
    print(f"  Average error: {avg_error:.3f}")
    print(f"  Maximum error: {max_error:.3f}")
    
    # Show estimates for top items
    print(f"  Estimates for top 5 items:")
    for item, true_count in most_frequent[:5]:
        estimated = cms.query(item)
        error_pct = (estimated - true_count) / true_count * 100
        print(f"    {item}: {estimated} (true: {true_count}, error: {error_pct:+.1f}%)")

# Visualize accuracy vs memory trade-off
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Error vs Memory
memory_usage = [r['memory_bytes'] for r in sketch_results]
avg_errors = [r['avg_error'] for r in sketch_results]
config_labels = [f"ε={r['config']['epsilon']}" for r in sketch_results]

ax1.scatter(memory_usage, avg_errors, s=100, alpha=0.7)
for i, label in enumerate(config_labels):
    ax1.annotate(label, (memory_usage[i], avg_errors[i]), 
                xytext=(5, 5), textcoords='offset points')

ax1.set_xlabel('Memory Usage (bytes)')
ax1.set_ylabel('Average Relative Error')
ax1.set_title('Count-Min Sketch: Accuracy vs Memory')
ax1.grid(True, alpha=0.3)

# Processing speed
processing_times = [r['processing_time'] * 1000 for r in sketch_results]
ax2.bar(config_labels, processing_times, alpha=0.7)
ax2.set_ylabel('Processing Time (ms)')
ax2.set_title('Stream Processing Speed')
ax2.tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()

print("\n📊 Count-Min Sketch Analysis:")
print("• Probabilistic data structure for frequency estimation")
print("• Never underestimates: with non-negative updates, estimates are upper bounds")
print("• Error decreases with smaller ε (more memory)")
print("• Excellent for finding heavy hitters in streams")
print("• Mergeable for distributed computing")

## Problem 4: Bloom Filter for Set Membership 🟢

**Difficulty**: Easy

**Problem**: Implement a Bloom filter for space-efficient approximate set membership testing.

**Constraints**:
- Support configurable false positive rate
- Memory usage: O(m) where m is optimal based on n and desired FP rate
- No false negatives allowed
- Support union operation

**Example**:
```python
bf = BloomFilter(capacity=1000, error_rate=0.01)
bf.add("item1")
bf.add("item2")
if "item1" in bf:  # Always True for added items (no false negatives)
    print("Found")
if "item3" in bf:  # Never added, so True here is a false positive
    print("Maybe found")
```
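
The capacity and error rate determine the filter's size. A minimal sketch, assuming the standard sizing formulas `m = -n·ln(p) / (ln 2)²` bits and `k = (m/n)·ln 2` hash functions, both rounded up; `bloom_parameters` and `expected_fp_rate` are hypothetical helper names used only for illustration.

```python
import math

def bloom_parameters(n: int, p: float) -> tuple:
    """Return (bits, hash_count) for n expected items and target FP rate p."""
    m = math.ceil(-n * math.log(p) / (math.log(2) ** 2))
    k = math.ceil((m / n) * math.log(2))
    return m, k

def expected_fp_rate(m: int, k: int, n: int) -> float:
    """Approximate FP rate after n insertions: (1 - e^(-kn/m))^k."""
    return (1.0 - math.exp(-k * n / m)) ** k

m, k = bloom_parameters(1000, 0.01)
print(f"bits={m}, hashes={k}, bits per item={m / 1000:.2f}")
print(f"approximate FP rate at capacity: {expected_fp_rate(m, k, 1000):.4f}")
```

At a 1% target this comes to just under 10 bits per item, regardless of how large the items themselves are.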

In [None]:
class BloomFilter:
    
    def __init__(self, capacity: int, error_rate: float = 0.01):
        """
        Initialize Bloom Filter.
        
        Args:
            capacity: Expected number of items
            error_rate: Desired false positive rate
        """
        self.capacity = capacity
        self.error_rate = error_rate
        
        # Calculate optimal parameters
        self.size = self._optimal_size(capacity, error_rate)
        self.hash_count = self._optimal_hash_count(self.size, capacity)
        
        # Initialize bit array
        self.bit_array = np.zeros(self.size, dtype=bool)
        
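        # Generate hash function seeds from a local RNG: fixed seed so filters
        # stay compatible for union(), without resetting the global random state
        rng = random.Random(42)
        self.hash_seeds = [rng.randint(0, 2**31 - 1) for _ in range(self.hash_count)]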
        
        # Track number of items added
        self.items_added = 0
        
        print(f"Bloom Filter: {self.size} bits, {self.hash_count} hash functions")
        print(f"Expected FP rate: {self.error_rate:.4f}")
    
    def _optimal_size(self, n: int, p: float) -> int:
        """Calculate optimal bit array size."""
        m = -(n * math.log(p)) / (math.log(2) ** 2)
        return int(math.ceil(m))
    
    def _optimal_hash_count(self, m: int, n: int) -> int:
        """Calculate optimal number of hash functions."""
        k = (m / n) * math.log(2)
        return int(math.ceil(k))
    
    def _hash(self, item: Any, seed: int) -> int:
        """Hash function for arbitrary items."""
        item_str = str(item).encode('utf-8')
        
        if mmh3 is not None:
            hash_val = mmh3.hash(item_str, seed, signed=False)
        else:
            # Fallback to built-in hash (salted per process)
            hash_val = hash((item_str, seed)) % (2**32)
        
        return hash_val % self.size
    
    def add(self, item: Any):
        """
        Add item to the filter.
        
        Time Complexity: O(k) where k is number of hash functions
        Space Complexity: O(1)
        """
        for seed in self.hash_seeds:
            index = self._hash(item, seed)
            self.bit_array[index] = True
        
        self.items_added += 1
    
    def __contains__(self, item: Any) -> bool:
        """
        Test membership (might have false positives).
        
        Time Complexity: O(k)
        Space Complexity: O(1)
        """
        for seed in self.hash_seeds:
            index = self._hash(item, seed)
            if not self.bit_array[index]:
                return False
        return True
    
    def union(self, other: 'BloomFilter') -> 'BloomFilter':
        """
        Union of two Bloom filters.
        Both filters must have same parameters.
        """
        if (self.size != other.size or 
            self.hash_count != other.hash_count or 
            self.hash_seeds != other.hash_seeds):
            raise ValueError("Filters must have same parameters for union")
        
        result = BloomFilter(self.capacity, self.error_rate)
        result.size = self.size
        result.hash_count = self.hash_count
        result.hash_seeds = self.hash_seeds
        result.bit_array = self.bit_array | other.bit_array
        result.items_added = self.items_added + other.items_added
        
        return result
    
    def false_positive_rate(self) -> float:
        """
        Calculate current false positive rate.
        """
        if self.items_added == 0:
            return 0.0
        
        # Fraction of bits that are set
        fraction_set = np.sum(self.bit_array) / self.size
        
        # Estimated FP rate
        fp_rate = fraction_set ** self.hash_count
        return fp_rate
    
    def get_stats(self) -> Dict[str, Any]:
        """Get statistics about the filter."""
        bits_set = int(np.sum(self.bit_array))
        load_factor = bits_set / self.size
        
        return {
            'size': self.size,
            'hash_count': self.hash_count,
            'items_added': self.items_added,
            'bits_set': bits_set,
            'load_factor': load_factor,
            'expected_fp_rate': self.error_rate,
            'actual_fp_rate': self.false_positive_rate(),
            'memory_bytes': (self.size + 7) // 8  # ideal bit-packed size, rounded up to whole bytes
        }

# Test Bloom Filter
print("\n=== Problem 4: Bloom Filter ===")

# Test different configurations
test_items = [f"item_{i}" for i in range(1000)]
non_existent_items = [f"missing_{i}" for i in range(500)]

configs = [
    {'capacity': 1000, 'error_rate': 0.001},  # Very low FP rate
    {'capacity': 1000, 'error_rate': 0.01},   # Standard FP rate
    {'capacity': 1000, 'error_rate': 0.1}     # High FP rate, less memory
]

bloom_results = []

for i, config in enumerate(configs):
    print(f"\nConfiguration {i+1}: capacity={config['capacity']}, error_rate={config['error_rate']}")
    
    # Create and populate filter
    bf = BloomFilter(**config)
    
    start_time = time.perf_counter()
    for item in test_items:
        bf.add(item)
    add_time = time.perf_counter() - start_time
    
    # Test membership for existing items (should all be True)
    start_time = time.perf_counter()
    existing_results = [item in bf for item in test_items[:100]]
    existing_time = time.perf_counter() - start_time
    
    # Test membership for non-existent items (count false positives)
    start_time = time.perf_counter()
    fp_results = [item in bf for item in non_existent_items]
    query_time = time.perf_counter() - start_time
    
    false_positives = sum(fp_results)
    fp_rate = false_positives / len(non_existent_items)
    
    stats = bf.get_stats()
    
    bloom_results.append({
        'config': config,
        'measured_fp_rate': fp_rate,
        'expected_fp_rate': config['error_rate'],
        'add_time': add_time,
        'query_time': query_time,
        'memory_bytes': stats['memory_bytes'],
        'load_factor': stats['load_factor']
    })
    
    print(f"  Memory usage: {stats['memory_bytes']} bytes")
    print(f"  Add time: {add_time*1000:.2f}ms for {len(test_items)} items")
    print(f"  Query time: {query_time*1000:.3f}ms for {len(non_existent_items)} items")
    print(f"  False positives: {false_positives}/{len(non_existent_items)} ({fp_rate:.4f})")
    print(f"  Expected FP rate: {config['error_rate']:.4f}")
    print(f"  Load factor: {stats['load_factor']:.3f}")
    
    # Verify no false negatives
    false_negatives = sum(1 for item in test_items[:100] if item not in bf)
    print(f"  False negatives: {false_negatives} (should be 0)")

# Test union operation
print("\nTesting Union Operation:")
bf1 = BloomFilter(capacity=500, error_rate=0.01)
bf2 = BloomFilter(capacity=500, error_rate=0.01)

# Add different items to each filter
items1 = [f"set1_item_{i}" for i in range(250)]
items2 = [f"set2_item_{i}" for i in range(250)]

for item in items1:
    bf1.add(item)
for item in items2:
    bf2.add(item)

# Union filters
bf_union = bf1.union(bf2)

# Test that union contains items from both sets
items1_in_union = sum(1 for item in items1[:50] if item in bf_union)
items2_in_union = sum(1 for item in items2[:50] if item in bf_union)

print(f"Items from set1 found in union: {items1_in_union}/50")
print(f"Items from set2 found in union: {items2_in_union}/50")

# Visualize results
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# FP Rate Comparison
expected_rates = [r['expected_fp_rate'] for r in bloom_results]
measured_rates = [r['measured_fp_rate'] for r in bloom_results]
config_labels = [f"ε={r['config']['error_rate']}" for r in bloom_results]

x = np.arange(len(config_labels))
width = 0.35

ax1.bar(x - width/2, expected_rates, width, label='Expected', alpha=0.7)
ax1.bar(x + width/2, measured_rates, width, label='Measured', alpha=0.7)

ax1.set_xlabel('Configuration')
ax1.set_ylabel('False Positive Rate')
ax1.set_title('Bloom Filter: Expected vs Measured FP Rate')
ax1.set_xticks(x)
ax1.set_xticklabels(config_labels)
ax1.legend()
ax1.set_yscale('log')

# Memory vs Accuracy
memory_usage = [r['memory_bytes'] for r in bloom_results]
ax2.scatter(memory_usage, measured_rates, s=100, alpha=0.7)
for i, label in enumerate(config_labels):
    ax2.annotate(label, (memory_usage[i], measured_rates[i]), 
                xytext=(5, 5), textcoords='offset points')

ax2.set_xlabel('Memory Usage (bytes)')
ax2.set_ylabel('False Positive Rate')
ax2.set_title('Memory vs Accuracy Trade-off')
ax2.set_yscale('log')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\n📊 Bloom Filter Analysis:")
print("• No false negatives, only false positives")
print("• Memory usage is independent of item size")
print("• Lower error rate requires more memory")
print("• Excellent for cache filtering and duplicate detection")
print("• Union is a lossless bitwise OR; intersection via AND is only approximate")

## Summary and Performance Analysis 📈

### 🏆 Data Structures Implemented:

1. **KD-Tree** 🔴
   - Efficient nearest neighbor search in low-medium dimensions
   - **Key Insight**: Performance degrades in high dimensions (curse of dimensionality)
   - **Best Use**: 2D-10D spatial data, exact nearest neighbors

2. **Locality Sensitive Hashing (LSH)** 🟡
   - Approximate nearest neighbors in high dimensions
   - **Key Insight**: Trade recall for speed, configurable parameters
   - **Best Use**: High-dimensional data, approximate similarity search

3. **Count-Min Sketch** 🟡
   - Frequency estimation in data streams
   - **Key Insight**: Never underestimates (for non-negative updates), tunable accuracy vs memory
   - **Best Use**: Heavy hitter detection, streaming analytics

4. **Bloom Filter** 🟢
   - Space-efficient set membership testing
   - **Key Insight**: No false negatives, controllable false positive rate
   - **Best Use**: Cache filtering, duplicate detection
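
The "controllable false positive rate" of a Bloom filter comes from how the bit count and hash count are chosen. The sketch below assumes the textbook sizing formulas m = -n·ln(p) / (ln 2)² and k = (m/n)·ln 2; the helper name `bloom_parameters` is hypothetical and may not match how the `BloomFilter` class in this notebook computes its own size.

```python
import math
from typing import Tuple

def bloom_parameters(capacity: int, error_rate: float) -> Tuple[int, int]:
    """Return (bits, hash_count) from the textbook Bloom filter formulas."""
    if capacity <= 0 or not (0.0 < error_rate < 1.0):
        raise ValueError("capacity must be positive and error_rate in (0, 1)")
    # m = -n * ln(p) / (ln 2)^2, rounded up to a whole number of bits
    bits = math.ceil(-capacity * math.log(error_rate) / (math.log(2) ** 2))
    # k = (m / n) * ln 2, rounded to the nearest integer (at least one hash)
    hash_count = max(1, round(bits / capacity * math.log(2)))
    return bits, hash_count

for p in (0.1, 0.01, 0.001):
    m, k = bloom_parameters(1000, p)
    print(f"error_rate={p}: {m} bits ({m / 1000:.1f} bits/item), {k} hash functions")
```

For a 1% target this works out to roughly 9.6 bits per item and 7 hash functions, independent of how large the items themselves are.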

### ⚡ Complexity Analysis:

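| Data Structure | Build / Insert | Query | Space | Best For |
|----------------|----------------|-------|-------|----------|
| **KD-Tree** | O(n log n) build | O(log n) avg | O(n) | Low dimensions (≤10) |
| **LSH** | O(nhd) build | O(h + c) | O(nh) | High dimensions (≥50) |
| **Count-Min** | O(d) per update | O(d) | O((1/ε) log(1/δ)) | Streaming frequency counts |
| **Bloom Filter** | O(k) per insert | O(k) | O(m) bits | Set membership |

Here n is the number of points, d the dimensionality (for Count-Min, d is the number of hash rows), h the number of LSH hash functions, c the number of candidates checked, k the number of Bloom filter hash functions, and m the filter size in bits.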
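
The Count-Min space bound in the table follows from how the sketch dimensions are usually derived from ε and δ. The sketch below assumes the standard analysis (width = ⌈e/ε⌉, depth = ⌈ln(1/δ)⌉); the helper name `count_min_dimensions` is hypothetical, and the Count-Min implementation earlier in this notebook may size its table differently.

```python
import math
from typing import Tuple

def count_min_dimensions(epsilon: float, delta: float) -> Tuple[int, int]:
    """Return (width, depth) so that, with probability at least 1 - delta,
    an estimate exceeds the true count by at most epsilon * total_count."""
    if not (0.0 < epsilon < 1.0) or not (0.0 < delta < 1.0):
        raise ValueError("epsilon and delta must both lie in (0, 1)")
    width = math.ceil(math.e / epsilon)       # counters per row
    depth = math.ceil(math.log(1.0 / delta))  # number of hash rows
    return width, depth

width, depth = count_min_dimensions(epsilon=0.01, delta=0.01)
print(f"width={width}, depth={depth}, counters={width * depth}")
```

The total counter count depends only on ε and δ, never on the stream length, which is what makes the structure suitable for streaming.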

### 🎯 Key Implementation Insights:

1. **Dimension Sensitivity**: KD-trees excel in low dimensions, LSH in high dimensions
2. **Accuracy Trade-offs**: Probabilistic structures offer configurable accuracy
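3. **Memory Efficiency**: Count-Min memory is fixed by (ε, δ) rather than by stream length; a Bloom filter needs a fixed number of bits per item regardless of item size
4. **Streaming Capability**: Count-Min Sketch processes unbounded streams in a single pass with fixed memory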
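
The dimension sensitivity noted above can be seen directly: as dimensionality grows, the nearest and farthest points from a query end up at similar distances, so a KD-tree's bounding-box pruning excludes fewer branches. The following is an illustrative sketch on uniform random data; the function name `nearest_to_farthest_ratio` and the chosen sizes are assumptions made for the demonstration.

```python
import numpy as np

def nearest_to_farthest_ratio(dim: int, n_points: int = 1000, seed: int = 0) -> float:
    """Ratio of nearest to farthest Euclidean distance from a random query."""
    rng = np.random.default_rng(seed)
    points = rng.random((n_points, dim))  # uniform points in the unit cube
    query = rng.random(dim)
    dists = np.linalg.norm(points - query, axis=1)
    return float(dists.min() / dists.max())

for dim in (2, 10, 100, 1000):
    ratio = nearest_to_farthest_ratio(dim)
    print(f"dim={dim:4d}  nearest/farthest = {ratio:.3f}")
```

A ratio close to 1 means almost every point is a plausible neighbor, which is the regime where approximate methods such as LSH tend to be the more practical choice.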

### 🚀 Production Considerations:

1. **Choose the Right Tool**:
   - KD-Tree: Exact NN in ≤10D
   - LSH: Approximate NN in high-D
   - Count-Min: Frequency queries
   - Bloom Filter: Membership tests

2. **Parameter Tuning**:
   - Balance accuracy vs memory vs speed
   - Consider data distribution and query patterns
   - Profile with realistic workloads

3. **Scalability**:
   - Distributed versions available for most structures
   - Consider mergeable properties for parallel processing
   - Plan for incremental updates vs rebuilds
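
The mergeable property mentioned above is what makes sketches convenient for parallel processing: two Count-Min tables built with the same hash functions and dimensions can be combined by element-wise addition. The sketch below is a minimal, self-contained illustration and not the notebook's own Count-Min class; `WIDTH`, `DEPTH`, `_bucket`, `build_sketch`, and `estimate` are hypothetical names, and the table sizes are arbitrary.

```python
import hashlib
from typing import Iterable

import numpy as np

WIDTH, DEPTH = 64, 4  # illustrative sizes, not tuned for any error bound

def _bucket(item: str, row: int) -> int:
    """Deterministic per-row hash so that separately built sketches agree."""
    digest = hashlib.blake2b(f"{row}:{item}".encode("utf-8"), digest_size=8).digest()
    return int.from_bytes(digest, "big") % WIDTH

def build_sketch(stream: Iterable[str]) -> np.ndarray:
    """Build a DEPTH x WIDTH counter table from a stream of items."""
    table = np.zeros((DEPTH, WIDTH), dtype=np.int64)
    for item in stream:
        for row in range(DEPTH):
            table[row, _bucket(item, row)] += 1
    return table

def estimate(table: np.ndarray, item: str) -> int:
    """Count-Min estimate: the minimum counter across all rows."""
    return int(min(table[row, _bucket(item, row)] for row in range(DEPTH)))

shard_a = ["cat", "dog", "cat", "bird"]
shard_b = ["cat", "fish", "dog"]

# Merging per-shard sketches gives the same table as one pass over all data.
merged = build_sketch(shard_a) + build_sketch(shard_b)
combined = build_sketch(shard_a + shard_b)
print("merge matches single pass:", np.array_equal(merged, combined))
print("estimated count for 'cat':", estimate(merged, "cat"))
```

The same idea applies to Bloom filters with identical parameters, where the merge is a bitwise OR instead of an addition.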

### 📚 Next Steps:
- Implement advanced variants (Ball trees, Random projection trees)
- Study production deployments of probabilistic structures (Cassandra's Bloom filters, Redis HyperLogLog)
- Explore learned indices and ML-optimized data structures
- Practice with real-world datasets and benchmarks