In [None]:
import pandas as pd
import numpy as np
#
import faiss
import pyterrier as pt
import timeit
import matplotlib.pyplot as plt

In [None]:
# Initialise PyTerrier only when pt.started() reports it is not running yet,
# so re-executing this cell does not call pt.init() a second time.
if not pt.started():
    pt.init()

In [None]:
# Intersection of two result lists
def list_intersection(l1, l2):
    """Measure how much of the first result row also appears in the second.

    Args:
        l1: 2-D array-like exposing ``tolist()``; only row 0 is read. This is
            the reference list whose size is the denominator of the ratio.
        l2: 2-D array-like exposing ``tolist()``; only row 0 is read.

    Returns:
        A tuple ``(ratio, common)`` where ``common`` is the set of items
        present in both rows and ``ratio`` is ``len(common)`` divided by the
        number of distinct items in row 0 of ``l1``. When that reference row
        is empty the ratio is ``0.0`` instead of raising ZeroDivisionError.
    """
    s1 = set(l1.tolist()[0])
    s2 = set(l2.tolist()[0])
    s3 = s1 & s2
    if not s1:
        # Nothing to compare against: define the overlap as zero.
        return 0.0, s3
    return len(s3) / len(s1), s3

In [None]:
# Cosine similarity
def cosine_sim(a, b):
    """Return the cosine similarity between two 1-D vectors.

    Args:
        a: first vector (anything ``np.dot`` / ``np.linalg.norm`` accept).
        b: second vector, same length as ``a``.

    Returns:
        ``dot(a, b) / (||a|| * ||b||)``. If either vector has zero norm the
        cosine is undefined; ``0.0`` is returned instead of the NaN (plus
        RuntimeWarning) that a 0/0 division would produce.
    """
    dot_product = np.dot(a, b)
    norm_a = np.linalg.norm(a)
    norm_b = np.linalg.norm(b)
    denominator = norm_a * norm_b
    if denominator == 0:
        # At least one vector is all zeros: no direction to compare.
        return 0.0
    return dot_product / denominator

---
### MAIN
---

In [None]:
# Fetch the "vaswani" dataset handle from PyTerrier and print whatever
# get_corpus() returns for it.
dataset = pt.get_dataset("vaswani")
print("Corpus Vaswani: %s " % dataset.get_corpus())

In [None]:
# Materialise the corpus iterator into a DataFrame; the cell displays its shape.
documents = pd.DataFrame(dataset.get_corpus_iter())
documents.shape

In [None]:
# Preview the first rows of the corpus DataFrame.
documents.head()

In [None]:
# Calcular embeddings de los documentos
from sentence_transformers import SentenceTransformer, util

In [None]:
# Load the pretrained sentence-embedding model by its hub identifier.
# NOTE(review): presumably the same model that produced the precomputed
# "-512" .npy embedding files loaded further down — confirm.
model = SentenceTransformer('sentence-transformers/distiluse-base-multilingual-cased-v2')

In [None]:
# Embed a small sample: the text of the first 10 documents only.
doc_text = documents["text"].iloc[:10].tolist()
doc_embeddings = model.encode(doc_text, convert_to_tensor=True)

In [None]:
# Show the shape of the sample document embeddings.
print(doc_embeddings.shape)

In [None]:
#doc_embeddings

In [None]:
# Encode a single free-text query with the same model used for the documents.
query_text = ["computer electronic"]
query_embeddings = model.encode(query_text, convert_to_tensor=True)

In [None]:
#query_embeddings

In [None]:
# Unpack the first (and only) query embedding into a plain list of its components.
qry = [*query_embeddings[0]]

In [None]:

# Score every sample document against the query with cosine similarity.
# Plain float32 vectors are built first: iterating a torch tensor yields 0-d
# tensors, which NumPy cannot convert when they live on a CUDA device and
# handles slowly otherwise. float() works for CPU and GPU scalars alike.
qry_vec = np.array([float(component) for component in qry], dtype=np.float32)
for i, doc in enumerate(doc_embeddings):
    # doc_embeddings was encoded with convert_to_tensor=True, so each row is
    # a tensor; move it to host memory before handing it to NumPy.
    doc_vec = doc.detach().cpu().numpy()
    score = cosine_sim(doc_vec, qry_vec)
    print ("Sim doc: ", i, " query: ", score)

In [None]:
# Display the query text that was encoded.
query_text

In [None]:
# Load the previously computed embeddings of all documents.
vaswani_docs_embeddings = np.load("../data/vaswani_docs_embeddings-512.npy")
vaswani_docs_embeddings.shape

In [None]:
# Get the 'topics' associated with the corpus.
topics = dataset.get_topics()
topics.head()

In [None]:
# Load the previously computed embeddings of all queries.
vaswani_query_embeddings = np.load("../data/vaswani_query_embeddings-512.npy")
vaswani_query_embeddings.shape

### Indexación con FAISS (diferentes índices)
**Más sobre los tipos de índices en FAISS:** https://github.com/facebookresearch/faiss/wiki/Faiss-indexes

**Prueba 1 - Flat Index**  
Recordar: En este tipo de índice se mide la distancia L2 (euclídea) entre el vector de query 
y todos los vectores de documentos almacenados. Es simple y preciso (pero no demasiado rápido).

In [None]:
# Initialisation
# Take the dimensionality from the loaded document embeddings instead of
# hard-coding 512, so the index always matches the vectors added to it.
d = vaswani_docs_embeddings.shape[1]
indexFlat = faiss.IndexFlatL2(d)

# Check the number of documents in the index (nothing has been added yet).
indexFlat.ntotal

In [None]:
# Agrego los documentos al índice
%time
indexFlat.add(vaswani_docs_embeddings)
indexFlat.ntotal

In [None]:
# Display the flat index's is_trained flag.
indexFlat.is_trained

In [None]:
# Retrieval example
# k is the number of results requested from search().
k = 10
# Wrap the first query embedding in a list so the array gains a leading axis.
query_vector = np.array([vaswani_query_embeddings[0]])

In [None]:
%time
DFlat, rsFlat = indexFlat.search(query_vector, k)  # Búsqueda

In [None]:
# Print the second value returned by indexFlat.search (the result ids).
print(rsFlat)

In [None]:
# Display the first value returned by indexFlat.search (the distances).
DFlat

In [None]:
# Persist the flat index to a file in the current working directory.
faiss.write_index(indexFlat, "vaswani_faiss_flat.ndx")

**Prueba 2 - IVF Flat Index**  
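Recordar: En este tipo de índice se particiona el espacio de búsqueda en `nlist` celdas y cada consulta
se compara solo contra los vectores de las `nprobe` celdas más cercanas, es decir, se realiza una búsqueda aproximada de vecinos más cercanos (ANN).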

In [None]:
nlist = 50  # Number of cells
# A flat L2 index is passed as the first argument of the IVF index
# (the quantizer, in FAISS terms).
base_index  = faiss.IndexFlatL2(d)
indexIVFFlat = faiss.IndexIVFFlat(base_index, d, nlist)

In [None]:
# Display is_trained before train() has been called on the IVF index.
indexIVFFlat.is_trained

In [None]:
# Preparo (train) las estructuras de datos del índice
%time
indexIVFFlat.train(vaswani_docs_embeddings)
indexIVFFlat.ntotal

In [None]:
# Display is_trained again, now after the train() call.
indexIVFFlat.is_trained

In [None]:
# Add the documents to the index; the cell displays the resulting count.
indexIVFFlat.add(vaswani_docs_embeddings)
indexIVFFlat.ntotal

In [None]:
# Persist the IVF index to a file in the current working directory.
faiss.write_index(indexIVFFlat, "vaswani_faiss_ivfflat.ndx")

In [None]:
# Retrieval example
k = 10
indexIVFFlat.nprobe = 1 # Search scope = 1 cell
# Same single query as in the flat-index example: first query embedding
# wrapped in a list so the array gains a leading axis.
query_vector = np.array([vaswani_query_embeddings[0]])

In [None]:
%time
DIVFFlat, rsIVFFlat = indexIVFFlat.search(query_vector, k)  # Búsqueda

In [None]:
# Display the second value returned by indexIVFFlat.search (the result ids).
rsIVFFlat

In [None]:
# Fraction of the flat-index results that the IVF index also returned,
# together with the shared items.
list_intersection(rsFlat, rsIVFFlat)

In [None]:
# Widen the search scope to 10 cells and repeat the same query.
indexIVFFlat.nprobe = 10
DIVFFlat, rsIVFFlat = indexIVFFlat.search(query_vector, k)  # Search

In [None]:
# Recompute the overlap against the flat-index results after widening nprobe.
list_intersection(rsFlat, rsIVFFlat)

In [None]:
# Comparación Flat vs IVFFlat para todos los queries

In [None]:
# np.array() makes a copy of the loaded query embeddings; the cell displays its shape.
query_vectors = np.array(vaswani_query_embeddings)  # All queries
query_vectors.shape

In [None]:
k = 10
# Time one batched search of all queries on the flat (exhaustive) index.
t0 = timeit.default_timer()
DFlat, rsFlat = indexFlat.search(query_vectors, k)
t1 = timeit.default_timer()
# timeit.default_timer() reports seconds. The variable is kept in seconds
# (later cells reuse it); only the printed value is converted so that it
# matches the "ms" unit in the message.
tiempo_exhaustivo = t1-t0
print(f"Elapsed time (exhaustivo): {tiempo_exhaustivo * 1000} ms")

In [None]:
# Time one batched search of all queries on the partitioned (IVF) index,
# probing 10 cells per query.
indexIVFFlat.nprobe = 10
t0 = timeit.default_timer()
DIVFFlat, rsIVFFlat = indexIVFFlat.search(query_vectors, k)
t1 = timeit.default_timer()
# timeit.default_timer() reports seconds. The variable is kept in seconds;
# only the printed value is converted so that it matches the "ms" unit in
# the message.
tiempo_nprobe10 = t1-t0
print(f"Elapsed time (nprobe=10): {tiempo_nprobe10 * 1000} ms")

In [None]:
# Time/overlap trade-off between the exhaustive flat index and the
# partitioned one. Each series starts with the exhaustive baseline, whose
# overlap with itself is 1 by definition.
labels = ["Exhaustivo"]
search_overlaps = [1]
search_times = [tiempo_exhaustivo]

for nprobe in (1, 10, 20, 30, 40, 50):
    indexIVFFlat.nprobe = nprobe

    # Time a single batched search over every query at this nprobe.
    t0 = timeit.default_timer()
    DIVFFlat, rsIVFFlat = indexIVFFlat.search(query_vectors, k)
    t1 = timeit.default_timer()

    labels.append(str(nprobe))
    search_times.append(t1 - t0)

    # Per query: share of the distinct flat-index results that the
    # partitioned index also returned; then average over all queries.
    per_query_overlap = [
        len(set(exact_row) & set(approx_row)) / len(set(exact_row))
        for exact_row, approx_row in zip(rsFlat, rsIVFFlat)
    ]
    search_overlaps.append(np.mean(per_query_overlap))

In [None]:
#search_times

In [None]:
#search_overlaps

In [None]:
# Trade-off plot: one point per configuration (exhaustive baseline plus one
# per nprobe). matplotlib.pyplot is already imported as plt in the
# notebook's first cell.
for i, (elapsed_s, overlap, label) in enumerate(zip(search_times, search_overlaps, labels)):
    # The baseline (first entry) gets a circle, every other point a cross.
    # Choosing the marker by position keeps the plot working for any number
    # of nprobe values instead of a fixed-length marker list.
    marker = 'o--' if i == 0 else 'x--'
    # Times were measured in seconds and overlaps as fractions in [0, 1];
    # scale both so the data match the axis units (ms and %).
    plt.plot(elapsed_s * 1000, overlap * 100, marker, markersize=8, label=label)
#
plt.grid()
plt.xlabel("Tiempo (ms)")
plt.ylabel("Overlap (%)")
#
plt.legend(loc=(1.05, 0.7))
plt.show()

### Tarea 
**Explorar el impacto del parámetro nlist (particiones) recuperando con nprobe = 1. Defina usted los valores a probar y ejecute todas las consultas. Reporte un gráfico de tradeoff entre nlist y overlap.**