In [1]:
## Quick tutorial on Approximate Nearest Neighbors (ANN) search.
## Source tutorial:
## https://towardsdatascience.com/comprehensive-guide-to-approximate-nearest-neighbors-algorithms-8b94f057d6b6

In [2]:
import pickle
import faiss
import annoy

In [3]:
def load_data(path='../Datasets/movies.pickle'):
    """Load a pickled dataset from disk and return the unpickled object.

    Args:
        path: Location of the pickle file. Defaults to the movies dataset
            used throughout this notebook.

    Returns:
        Whatever object was stored in the pickle file.

    Raises:
        OSError: If the file cannot be opened.
    """
    # NOTE: pickle.load can execute arbitrary code embedded in the file;
    # only point this at files from a trusted source.
    with open(path, 'rb') as f:
        return pickle.load(f)

# Load the dataset once; the cells below read `data` directly.
data = load_data()

# Quick sanity check, one item per line: number of names, the first five
# names, and the first five components of the first vector.
print(len(data['name']), data['name'][:5], data['vector'][0][:5], sep='\n')

1682
['Toy Story (1995)' 'GoldenEye (1995)' 'Four Rooms (1995)'
 'Get Shorty (1995)' 'Copycat (1995)']
[-0.01780608 -0.14265831  0.10308606 -0.41564542  0.13982998]


In [4]:
class ExactIndex():
    """Exact (brute-force) nearest-neighbour lookup using ``faiss.IndexFlatL2``.

    Call ``build()`` once before ``query()``.
    """

    def __init__(self, vectors, labels):
        """Store the vectors to index and the label attached to each row.

        Args:
            vectors: 2-D array exposing ``shape`` and ``astype`` (presumably a
                NumPy array); one row per item.
            labels: Indexable collection where ``labels[i]`` describes row ``i``.
        """
        self.dimension = vectors.shape[1]
        # Kept as float32 because that is what gets handed to the faiss index.
        self.vectors = vectors.astype('float32')
        self.labels = labels

    def build(self):
        """Create the flat L2 index and add every stored vector to it."""
        self.index = faiss.IndexFlatL2(self.dimension)
        self.index.add(self.vectors)

    def query(self, vectors, k=10):
        """Return the labels of the ``k`` nearest neighbours of one query vector.

        Args:
            vectors: Either a single 1-D query vector or a 2-D array of query
                vectors. Only the FIRST row of a 2-D input is searched, which
                matches the previous behaviour of returning results for row 0.
            k: Number of neighbours to request.

        Returns:
            List of labels, closest first. May be shorter than ``k`` when the
            index holds fewer than ``k`` items.
        """
        # Promote a lone vector to a one-row matrix; for a matrix keep only
        # the first row so we do not search rows whose results are discarded.
        if vectors.ndim == 1:
            first_row = vectors.reshape(1, -1)
        else:
            first_row = vectors[:1]
        # Match the dtype used when the index was populated.
        first_row = first_row.astype('float32', order='C')
        _, indices = self.index.search(first_row, k)
        # faiss pads missing neighbours with -1; skip those so they do not
        # silently resolve to labels[-1].
        return [self.labels[i] for i in indices[0] if i >= 0]
    
# Exact baseline: index every vector in the dataset, labelled by its name.
index = ExactIndex(vectors=data["vector"], labels=data["name"])
index.build()

In [5]:
# query() only reports neighbours for the first row, so pass just that row
# (kept 2-D via the slice) instead of searching the whole dataset.
index.query(data['vector'][:1])

['Toy Story (1995)',
 'Rock, The (1996)',
 'Return of the Jedi (1983)',
 'Willy Wonka and the Chocolate Factory (1971)',
 'Phenomenon (1996)',
 'Star Trek: First Contact (1996)',
 'Star Wars (1977)',
 'Hunchback of Notre Dame, The (1996)',
 'Birdcage, The (1996)',
 'Mars Attacks! (1996)']

In [6]:
class AnnoyIndex():
    """Approximate nearest-neighbour lookup backed by ``annoy.AnnoyIndex``.

    Call ``build()`` once before ``query()``.
    """

    def __init__(self, vectors, labels):
        """Store the vectors to index and the label attached to each row.

        Args:
            vectors: 2-D array exposing ``shape`` and ``astype`` (presumably a
                NumPy array); one row per item.
            labels: Indexable collection where ``labels[i]`` describes row ``i``.
        """
        self.dimension = vectors.shape[1]
        self.vectors = vectors.astype('float32')
        self.labels = labels

    def build(self, number_of_trees=5, metric='angular', seed=None):
        """Create the Annoy index, add every stored vector and build the trees.

        Args:
            number_of_trees: Number of trees passed to Annoy's ``build``.
            metric: Distance metric handed to ``annoy.AnnoyIndex``. Defaults
                to ``'angular'`` (the previous hard-coded value); pass e.g.
                ``'euclidean'`` to rank by a Euclidean distance like
                ``ExactIndex`` does.
            seed: Optional integer seed for Annoy's random number generator.
                ``None`` leaves Annoy's own seeding untouched.
        """
        self.index = annoy.AnnoyIndex(self.dimension, metric=metric)
        if seed is not None:
            # Set before adding items so tree construction is repeatable.
            self.index.set_seed(seed)
        for item_id, row in enumerate(self.vectors):
            self.index.add_item(item_id, row.tolist())
        self.index.build(number_of_trees)

    def query(self, vector, k=10):
        """Return the labels of the ``k`` approximate neighbours of ``vector``.

        Args:
            vector: A single query vector; either an array exposing
                ``tolist()`` or a plain sequence of numbers.
            k: Number of neighbours to request.

        Returns:
            List of labels in the order Annoy returned the item ids.
        """
        # Annoy is given a plain Python list, as before; plain sequences are
        # now accepted too.
        as_list = vector.tolist() if hasattr(vector, 'tolist') else list(vector)
        indices = self.index.get_nns_by_vector(as_list, k)
        return [self.labels[i] for i in indices]
    
# Approximate index over the same vectors and labels; rebinds `index`.
index = AnnoyIndex(vectors=data["vector"], labels=data["name"])
index.build()

In [7]:
# Approximate neighbours of the first item (row 0) in the dataset.
index.query(vector=data['vector'][0])

['Toy Story (1995)',
 'Return of the Jedi (1983)',
 'Star Wars (1977)',
 'Willy Wonka and the Chocolate Factory (1971)',
 'Star Trek: First Contact (1996)',
 'Fargo (1996)',
 'Men in Black (1997)',
 'Aladdin (1992)',
 'Leaving Las Vegas (1995)',
 'Beauty and the Beast (1991)']

In [8]:
# Approximate neighbours of the item stored at row 20.
index.query(vector=data['vector'][20])

['Muppet Treasure Island (1996)',
 'James and the Giant Peach (1996)',
 'Fantasia (1940)',
 'Father of the Bride Part II (1995)',
 '101 Dalmatians (1996)',
 'Matilda (1996)',
 'That Thing You Do! (1996)',
 'Hunchback of Notre Dame, The (1996)',
 'Mystery Science Theater 3000: The Movie (1996)',
 'Cinderella (1950)']