# Approximate Nearest Neighbours with FAISS [(credits)](https://github.com/mneedham/LearnDataWithMark/blob/main/faiss-ann/notebooks/ANN-Tutorial.ipynb)
In this notebook, we're going to learn how to do approximate nearest neighbours in FAISS using a cell probe method.

In [6]:
!pip install faiss-cpu pandas numpy

Defaulting to user installation because normal site-packages is not writeable


In [1]:
import faiss
import copy
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly_functions import generate_distinct_colors, zoom_in, create_plot, plot_points

In [2]:
dimensions = 2
number_of_vectors = 10_000
vectors = np.random.random((number_of_vectors, dimensions)).astype(np.float32)

In [3]:
fig = create_plot()
plot_points(fig, points=vectors, color='#CCCCCC', label="Data")
fig

In [4]:
search_vector = np.array([[0.5, 0.5]])

In [5]:
plot_points(fig, points=search_vector, color='black', label="Search Vector", symbol="x", size=10)
fig

## Creating a cell probe index
When creating the index, we need to specify how many partitions (or cells) we want to divide the vector space into.

In [6]:
cells = 10

quantizer = faiss.IndexFlatL2(dimensions)
index = faiss.IndexIVFFlat(quantizer, dimensions, cells)

In [7]:
index.train(vectors)

In [8]:
centroids = index.quantizer.reconstruct_n(0, index.nlist)
centroids

array([[0.16363522, 0.15642376],
       [0.5535022 , 0.8781461 ],
       [0.46925026, 0.3792651 ],
       [0.14681713, 0.86282074],
       [0.87329197, 0.7753164 ],
       [0.55069596, 0.12103064],
       [0.14718491, 0.5009223 ],
       [0.6540719 , 0.57730514],
       [0.3552203 , 0.6984649 ],
       [0.86214733, 0.24901374]], dtype=float32)

## Visualising cells and centroids
Let's update our chart to show the centroids and to which cell each vector will be assigned.

In [9]:
_, cell_ids = index.quantizer.search(vectors, k=1)
cell_ids = cell_ids.flatten()
cell_ids[:10]

array([5, 0, 1, 7, 7, 9, 4, 2, 5, 9])

In [10]:
color_map = generate_distinct_colors(index.nlist)

fig_cells = create_plot()

unique_ids = np.unique(cell_ids)
for uid in unique_ids:
  mask = (cell_ids == uid)
  masked_vectors = vectors[mask]
  plot_points(fig_cells, masked_vectors, color_map[uid], "Cell {}".format(uid), size=6)

plot_points(fig_cells, centroids, symbol="diamond-tall", color="black", size=15, showlegend=False)
plot_points(fig_cells, search_vector, symbol="x", color="black", size=15, label="Search Vector")

fig_cells

In [19]:
!pip install -U kaleido

Defaulting to user installation because normal site-packages is not writeable


In [None]:
import plotly.io as pio

current_width = fig_cells.layout.width
current_height = fig_cells.layout.height
print(current_width, current_height)


pio.orca.config.executable = '/usr/local/bin/orca'  # Replace with the actual path to your Orca executable


desired_width = 1600  # or whatever width you prefer in pixels
scaling_factor = desired_width / current_width if current_width else None

# If you know the current width, scale the height proportionally.
desired_height = current_height * scaling_factor if current_height and scaling_factor else None
fig_cells.update_layout(showlegend=False)
fig_cells.update_layout(template="plotly_white", margin=dict(t=0, b=0, l=0, r=0))
pio.write_image(fig_cells, './img/figure_high_res.png', width=1280, height=720, scale=3.0)


None None


## Searching for our vector
Let's add the vectors to the index and look for our search vector.

In [12]:
index.add(vectors)

When using a cell probe index, we can specify how many cells we want to use in the search. More cells will mean a slower, but potentially more accurate search.

In [13]:
index.nprobe

1

In [14]:
%%time 
distances, indices = index.search(search_vector, k=10)

df_ann = pd.DataFrame({
  "id": indices[0],
  "vector": [vectors[id] for id in indices[0]],
  "distance": distances[0],
})
df_ann

CPU times: user 1.01 ms, sys: 1.75 ms, total: 2.77 ms
Wall time: 4.56 ms


Unnamed: 0,id,vector,distance
0,2181,"[0.5013984, 0.5065762]",4.5e-05
1,1458,"[0.5007656, 0.49282777]",5.2e-05
2,3140,"[0.49461138, 0.49191767]",9.4e-05
3,5725,"[0.5127078, 0.49835032]",0.000164
4,2381,"[0.5044129, 0.5135126]",0.000202
5,8035,"[0.5052558, 0.4835363]",0.000299
6,4715,"[0.5197053, 0.5003462]",0.000388
7,9053,"[0.5185322, 0.49210975]",0.000406
8,1102,"[0.52128834, 0.48609325]",0.000647
9,5476,"[0.520393, 0.5152053]",0.000647


In [15]:
_, search_vectors_cell_ids = index.quantizer.search(search_vector, k=1)
unique_searched_ids = search_vectors_cell_ids[0]
unique_searched_ids

array([2])

In [17]:
fig_search = create_plot()

for uid in unique_searched_ids:
  mask = (cell_ids == uid)
  masked_vectors = vectors[mask]
  plot_points(fig_search, masked_vectors, color_map[uid], label="Cell {}".format(uid))
  plot_points(fig_search, centroids[uid].reshape(1, -1), symbol="diamond-tall", color="black", size=10, label="Centroid for Cell {}".format(uid), showlegend=False)

plot_points(fig_search, points=search_vector, color='red', label="Search Vector", symbol="x", size=10)

ann_vectors = np.array(df_ann["vector"].tolist())
plot_points(fig_search, points=ann_vectors, color='black', label="Approx Nearest Neighbors")

fig_search

# Brute Force Nearest Neighbours
How well did this approach work compared to a brute force one?

In [18]:
brute_force_index = faiss.IndexFlatL2(dimensions)
brute_force_index.add(vectors)

In [20]:
%%time
distances, indices = brute_force_index.search(search_vector, k=10)

pd.DataFrame({
  "id": indices[0],
  "vector": [vectors[id] for id in indices[0]],
  "distance": distances[0],
  "cell": [cell_ids[id] for id in indices[0]]
})

CPU times: user 1.17 ms, sys: 883 µs, total: 2.05 ms
Wall time: 1.65 ms


Unnamed: 0,id,vector,distance,cell
0,2181,"[0.5013984, 0.5065762]",4.5e-05,2
1,1458,"[0.5007656, 0.49282777]",5.2e-05,2
2,3140,"[0.49461138, 0.49191767]",9.4e-05,2
3,5725,"[0.5127078, 0.49835032]",0.000164,2
4,2381,"[0.5044129, 0.5135126]",0.000202,2
5,8035,"[0.5052558, 0.4835363]",0.000299,2
6,4715,"[0.5197053, 0.5003462]",0.000388,2
7,9053,"[0.5185322, 0.49210975]",0.000406,2
8,1102,"[0.52128834, 0.48609325]",0.000647,2
9,5476,"[0.520393, 0.5152053]",0.000647,2


In [21]:
index.nprobe = 2

In [22]:
index.quantizer.search(search_vector, k=2)

(array([[0.01552246, 0.02971424]], dtype=float32), array([[2, 7]]))

In [23]:
%%time
distances, indices = index.search(search_vector, k=10)

pd.DataFrame({
  "id": indices[0],
  "vector": [vectors[id] for id in indices[0]],
  "distance": distances[0],
  "cell": [cell_ids[id] for id in indices[0]]
})

CPU times: user 1.01 ms, sys: 344 µs, total: 1.35 ms
Wall time: 1.13 ms


Unnamed: 0,id,vector,distance,cell
0,2181,"[0.5013984, 0.5065762]",4.5e-05,2
1,1458,"[0.5007656, 0.49282777]",5.2e-05,2
2,3140,"[0.49461138, 0.49191767]",9.4e-05,2
3,5725,"[0.5127078, 0.49835032]",0.000164,2
4,2381,"[0.5044129, 0.5135126]",0.000202,2
5,8035,"[0.5052558, 0.4835363]",0.000299,2
6,4715,"[0.5197053, 0.5003462]",0.000388,2
7,9053,"[0.5185322, 0.49210975]",0.000406,2
8,1102,"[0.52128834, 0.48609325]",0.000647,2
9,5476,"[0.520393, 0.5152053]",0.000647,2
