# 0. Env

In [None]:
import time

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import faiss

In [None]:
# Dictionary used for the speed comparison: each timing cell stores its
# measured search duration here under a label naming the index type.
res_time_dict = {}

# 1. Context & Query Vector
- 성능 확인을 위해서 다수의 context 벡터와, query 벡터를 가정

In [None]:
# Benchmark configuration shared by every section of the notebook.
d = 256                # dimensionality of each vector
nb = 1_000_000         # how many context (database) vectors to generate
nq = 1_000             # how many query vectors to generate
np.random.seed(1234)   # seed NumPy's global RNG so the generated data is repeatable

In [None]:
# Context (database) vectors: an (nb, d) matrix of random values cast to float32.
xb = np.random.random(size=(nb, d)).astype(np.float32)
# Shift the first coordinate of row i by i / 1000 so the rows are not
# identically distributed along that axis.
xb[:, 0] += np.arange(nb) / 1000.0

In [None]:
# Query vectors: an (nq, d) matrix of random values cast to float32.
xq = np.random.random(size=(nq, d)).astype(np.float32)
# Apply the same first-coordinate shift as the context vectors (row i gets i / 1000).
xq[:, 0] += np.arange(nq) / 1000.0

# 2. Flat: 브루트포스

In [None]:
# Create an exact (brute-force) L2-distance index over d-dimensional vectors.
index = faiss.IndexFlatL2(d)
# Cell output: the index's "trained" flag.
index.is_trained

In [None]:
# Add the context embeddings to the index.
index.add(xb)
# Cell output: number of vectors now stored in the index.
index.ntotal

In [None]:
k = 10  # number of nearest neighbours to retrieve per query

# Time the brute-force search. perf_counter() is a monotonic, high-resolution
# clock intended for measuring short intervals; time.time() is a wall clock
# that can jump if the system time is adjusted.
start = time.perf_counter()
D, I = index.search(xq, k)  # D: distances, I: ids of the neighbours found
end = time.perf_counter()
elapsed = end - start

print(I)
res_time_dict['Flat'] = elapsed  # seconds, recorded for the final comparison
print('total time:', elapsed)

# 3. LSH: Locality Sensitive Hashing

In [None]:
nbits = d // 4  # number of hash bits per vector (a quarter of the dimensionality)
# Initialise the LSH index and add the context vectors.
index = faiss.IndexLSH(d, nbits)
index.add(xb)

In [None]:
k = 10  # number of nearest neighbours to retrieve per query

# Time the LSH search with a monotonic, high-resolution clock
# (time.time() is a wall clock and can jump if the system time changes).
start = time.perf_counter()
D, I = index.search(xq, k)  # D: distances, I: ids of the neighbours found
end = time.perf_counter()
elapsed = end - start

print(I)
res_time_dict['LSH'] = elapsed  # seconds, recorded for the final comparison
print('total time:', elapsed)

# 4. HNSW: Hierarchical Navigable Small World Graphs

In [None]:
# HNSW index parameters.
M = 64  # number of graph neighbours (links) kept per vertex
ef_search = 32  # expansion factor at search time: size of the candidate list explored per query
ef_construction = 64  # expansion factor at build time: size of the candidate list explored per insertion

In [None]:
# Building the graph takes a long time, so only 100,000 vectors are used.
# Initialise the index.
index = faiss.IndexHNSWFlat(d, M)
# Set the efConstruction and efSearch parameters (before adding, so that
# efConstruction applies while the graph is built).
index.hnsw.efConstruction = ef_construction
index.hnsw.efSearch = ef_search
# Add data to the index: only the first 100,000 rows of xb.
index.add(xb[:100000])

In [None]:
k = 10  # number of nearest neighbours to retrieve per query

# Time the HNSW search with a monotonic, high-resolution clock
# (time.time() is a wall clock and can jump if the system time changes).
# NOTE(review): this index was built from xb[:100000] only, so the timing is
# not directly comparable with sections that index all of xb.
start = time.perf_counter()
D, I = index.search(xq, k)  # D: distances, I: ids of the neighbours found
end = time.perf_counter()
elapsed = end - start

print(I)
res_time_dict['HNSW'] = elapsed  # seconds, recorded for the final comparison
print('total time:', elapsed)

# 5. Inverted File Index

## 5.1. IndexIVFFlat query

In [None]:
# IndexIVFFlat partitioning: vectors are assigned to one of nlist cells.
nlist = 50 # how many cells (partitions) the vectors are split into
# Flat L2 index passed in as the coarse quantizer.
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFFlat(quantizer, d, nlist)

In [None]:
# Cell output: the "trained" flag before index.train() has been called.
index.is_trained

In [None]:
# Train the index on the context vectors.
index.train(xb)
index.is_trained  # check if index is now trained

In [None]:
# Add all context vectors to the trained index.
index.add(xb)
index.ntotal  # number of embeddings indexed

In [None]:
# Cell output: current nprobe, i.e. how many cells are visited per query.
# NOTE(review): the following timing is labelled 'IVF.1', which presumably
# reflects a default of 1 - confirm against this cell's output.
index.nprobe

In [None]:
k = 10  # number of nearest neighbours to retrieve per query

# Time the IVF search at the index's current nprobe, using a monotonic,
# high-resolution clock (time.time() is a wall clock and can jump if the
# system time changes).
start = time.perf_counter()
D, I = index.search(xq, k)  # D: distances, I: ids of the neighbours found
end = time.perf_counter()
elapsed = end - start

print(I)
res_time_dict['IVF.1'] = elapsed  # seconds, recorded for the final comparison
print('total time:', elapsed)

## 5.2. Increase nprobe

In [None]:
# Increase the number of cells probed per query to 10.
index.nprobe = 10

In [None]:
k = 10  # number of nearest neighbours to retrieve per query

# Time the IVF search again after nprobe was raised, using a monotonic,
# high-resolution clock (time.time() is a wall clock and can jump if the
# system time changes).
start = time.perf_counter()
D, I = index.search(xq, k)  # D: distances, I: ids of the neighbours found
end = time.perf_counter()
elapsed = end - start

print(I)
res_time_dict['IVF.10'] = elapsed  # seconds, recorded for the final comparison
print('total time:', elapsed)

# 6. Product Quantization

In [None]:
# Build an IndexIVFPQ index (inverted file + product quantization).
m = 8  # number of sub-vectors each vector is split into (one centroid ID per sub-vector in the compressed code)
bits = 8 # number of bits used to encode each centroid ID

quantizer = faiss.IndexFlatL2(d)  # we keep the same L2 distance flat index as the coarse quantizer
# nlist is the value defined in the IndexIVFFlat section.
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, bits)

In [None]:
# Cell output: the "trained" flag before index.train() has been called.
index.is_trained

In [None]:
# Train the index on the context vectors.
index.train(xb)
index.is_trained  # check if index is now trained

In [None]:
# Add all context vectors to the trained index.
index.add(xb)
index.ntotal  # number of embeddings indexed

In [None]:
# Probe 10 cells per query so the setting matches the 'IVF.10' run.
index.nprobe = 10  # align to previous IndexIVFFlat nprobe value

In [None]:
k = 10  # number of nearest neighbours to retrieve per query

# Time the IVF-PQ search with a monotonic, high-resolution clock
# (time.time() is a wall clock and can jump if the system time changes).
start = time.perf_counter()
D, I = index.search(xq, k)  # D: distances, I: ids of the neighbours found
end = time.perf_counter()
elapsed = end - start

print(I)
res_time_dict['PQ'] = elapsed  # seconds, recorded for the final comparison
print('total time:', elapsed)

# 7. Query On GPU

In [None]:
# Create an exact (brute-force) L2-distance index on the CPU; it is copied to
# the GPU in the next cell.
index = faiss.IndexFlatL2(d)
# Cell output: the index's "trained" flag.
index.is_trained

In [None]:
res = faiss.StandardGpuResources()                # create the GPU resources object
gpu_index = faiss.index_cpu_to_gpu(res, 0, index) # transfer the index to the GPU (device 0)

In [None]:
# Add the context embeddings to the GPU index (the CPU `index` is not modified here).
gpu_index.add(xb)
# Cell output: number of vectors now stored in the GPU index.
gpu_index.ntotal

In [None]:
k = 10  # number of nearest neighbours to retrieve per query

# Time the GPU search with a monotonic, high-resolution clock
# (time.time() is a wall clock and can jump if the system time changes).
start = time.perf_counter()
# Query gpu_index, not index: the vectors were added to the GPU copy only, so
# the CPU-side `index` is still empty and would return no neighbours.
D, I = gpu_index.search(xq, k)  # D: distances, I: ids of the neighbours found
end = time.perf_counter()
elapsed = end - start

print(I)
res_time_dict['GPU'] = elapsed  # seconds, recorded for the final comparison
print('total time:', elapsed)

# 8. Visualization

In [None]:
# Collect the recorded timings into a two-column table:
# one row per benchmarked index, in the order the timings were stored.
data = {
    'algorithm': [*res_time_dict.keys()],
    'res time': [*res_time_dict.values()],
}
df = pd.DataFrame(data)
df  # cell output: the timing table

In [None]:
# Bar chart of the recorded search time for each algorithm.
df.plot.bar(x='algorithm', y='res time')
plt.show()