In [None]:
import random
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import torch

class LSH:
    """Random-hyperplane locality-sensitive hash index for vectors.

    Each vector is hashed to a tuple of ``num_bits`` bits, one per random
    hyperplane: bit ``i`` is 1 when the vector's dot product with hyperplane
    ``i`` is positive, else 0. Vectors sharing a bit pattern share a bucket.
    Queries visit buckets in increasing Hamming distance from the query's
    bit pattern.
    """

    def __init__(self, num_bits, dim, seed=None):
        """
        Initialize LSH with random hyperplanes.
        :param num_bits: Number of hash functions (i.e., hyperplanes)
        :param dim: Dimension of input vectors
        :param seed: Optional integer seed. When given, the hyperplanes are
            drawn from a dedicated ``torch.Generator``, so the index is
            reproducible without touching torch's global RNG state.
        """
        self.num_bits = num_bits
        self.dim = dim
        generator = None if seed is None else torch.Generator().manual_seed(seed)
        # One hyperplane per row: shape (num_bits, dim).
        self.hyperplanes = torch.randn(num_bits, dim, generator=generator)
        # Maps bit-pattern tuple -> list of (identifier, vector) entries.
        self.hash_table = defaultdict(list)

    def _hash(self, vector):
        """Return the bit-pattern tuple (of 0/1 ints) for ``vector``.

        The input is first converted to the hyperplanes' dtype and device.
        This lets float64 tensors, tensors on another device, and plain
        sequences hash like float32 CPU tensors instead of failing inside
        ``matmul``.
        """
        vector = torch.as_tensor(
            vector, dtype=self.hyperplanes.dtype, device=self.hyperplanes.device
        )
        projections = torch.matmul(self.hyperplanes, vector)
        return tuple((projections > 0).int().tolist())

    def insert(self, vector, identifier):
        """Insert a vector with an identifier into the hash table.

        The original ``vector`` object (not the dtype-converted copy used for
        hashing) is stored alongside ``identifier``.
        """
        hash_value = self._hash(vector)
        self.hash_table[hash_value].append((identifier, vector))

    def query(self, vector, k):
        """Find up to k approximate nearest neighbors for a given vector.

        Buckets are visited in increasing Hamming distance from the query's
        bit pattern. Whole buckets are taken while they fit. Only the last,
        boundary bucket is randomly subsampled to fill the remaining slots,
        so entries from a closer bucket are never dropped in favour of
        entries from a farther one.

        :param vector: Query vector of length ``dim``.
        :param k: Maximum number of neighbors to return; ``k <= 0`` yields ``[]``.
        :return: List of ``(identifier, vector)`` entries, closest buckets first.
        """
        hash_value = self._hash(vector)

        ranked_buckets = sorted(
            self.hash_table,
            key=lambda h: sum(a != b for a, b in zip(h, hash_value)),
        )

        neighbors = []
        for bucket in ranked_buckets:
            remaining = k - len(neighbors)
            if remaining <= 0:
                break
            entries = self.hash_table[bucket]
            if len(entries) <= remaining:
                neighbors.extend(entries)
            else:
                # Ties within the boundary bucket are broken randomly. Sampling
                # a list (not a set) also works on Python 3.11+, where
                # random.sample no longer accepts sets.
                neighbors.extend(random.sample(entries, remaining))
        return neighbors

    def plot_bucket_distribution(self):
        """Plot the frequency of samples in each bucket.

        Buckets appear in sorted bit-pattern order, labelled as bit strings
        (e.g. ``0110``), so the x-axis ordering is stable between runs.
        """
        bucket_labels = sorted(self.hash_table)
        bucket_sizes = [len(self.hash_table[b]) for b in bucket_labels]
        tick_labels = ["".join(map(str, b)) for b in bucket_labels]

        plt.figure(figsize=(12, 6))
        plt.bar(range(len(bucket_labels)), bucket_sizes, tick_label=tick_labels)
        plt.xlabel("Bucket Hash Values")
        plt.ylabel("Number of Samples")
        plt.title("Frequency of Samples in Each Bucket")
        plt.xticks(rotation=90)
        plt.show()

# Example usage
# Location of the saved training embeddings. A relative path is resolved
# against the current working directory, so the notebook must be started from
# the directory that contains `srcFiles/`.
embeddings_path = Path("./srcFiles/train_embeddings.pth")
if not embeddings_path.is_file():
    raise FileNotFoundError(
        f"Embeddings file not found: {embeddings_path.resolve()} "
        f"(current working directory: {Path.cwd()})"
    )

# Load train embeddings onto the CPU, where LSH keeps its hyperplanes.
# NOTE(review): torch.load unpickles the file; only load trusted files, or pass
# weights_only=True on torch versions that support it.
train_embeddings = torch.load(embeddings_path, map_location="cpu")
if len(train_embeddings) == 0:
    raise ValueError(f"No embeddings found in {embeddings_path}")

# Take the dimension from the data rather than hard-coding it, so the
# hyperplanes always match the embedding size.
# Assumes each entry is a 1-D vector; TODO confirm against the saved file.
dim = len(train_embeddings[0])
num_bits = 10  # Number of hash functions (hyperplanes) -> at most 2**10 buckets
lsh = LSH(num_bits, dim)

for i, vec in enumerate(train_embeddings):
    lsh.insert(vec, i)

# Query with a new vector
# query_vector = torch.randn(dim)
# neighbors = lsh.query(query_vector, k=5)
# print("Approximate neighbor ids:", [idx for idx, _ in neighbors])

# Plot bucket distribution
lsh.plot_bucket_distribution()

FileNotFoundError: [Errno 2] No such file or directory: 'train_embeddings.pth'