In [None]:
import os
import time

import faiss
import numpy as np
import pandas as pd
import torch
from datasets import Dataset, load_dataset
from sentence_transformers import SentenceTransformer
from torch.utils.data import DataLoader
from tqdm import tqdm

In [None]:
# Sentence encoder used for both the corpus and the queries below.
model = SentenceTransformer("distilbert-base-nli-mean-tokens")

# Load the "concat" configuration of the dataset and keep its "train" split.
ag_news = load_dataset("pietrolesci/ag_news", "concat")
train_ds = ag_news["train"]


### Create `faiss.IndexIDMap`

In [None]:
# Build a flat inner-product index over the sentence embeddings, wrapped in an
# IndexIDMap so that every vector is stored under an explicit integer id.
# The vector width is read from the model instead of being hard-coded, so the
# index stays valid if the encoder above is swapped for one of another size.
# NOTE(review): embeddings are added un-normalised, so the inner-product scores
# are not cosine similarities — confirm this is the intended metric.
embedding_dim = model.get_sentence_embedding_dimension()
flat_index = faiss.IndexFlatIP(embedding_dim)
index = faiss.IndexIDMap(flat_index)

count = 0
for batch in tqdm(DataLoader(train_ds["text"], batch_size=1_000)):
    # Ids are the running row positions of the texts in `train_ds["text"]`.
    # FAISS requires int64 ids; the dtype is set explicitly because NumPy's
    # default integer type is platform-dependent (int32 on some systems).
    ids = np.arange(count, count + len(batch), dtype=np.int64)
    embeddings = model.encode(batch, convert_to_numpy=True).astype(np.float32)

    index.add_with_ids(embeddings, ids)
    count += len(batch)

# Persist the populated index so later cells can reload it without re-encoding.
faiss.write_index(index, "train_ag_news.faiss")

In [None]:
# Reload the serialized index from disk under its own name.
index_path = "train_ag_news.faiss"
faiss_index = faiss.read_index(index_path)

In [None]:
# Encode the query strings with the same model, then fetch the 5
# highest-scoring index entries for each query (one result row per query).
query_texts = ["France", "Italy"]
top_k = 5

query = model.encode(query_texts)
scores, indices = faiss_index.search(query, top_k)

In [None]:
# Show the texts of the entries retrieved for the first query; the returned
# ids are used directly as row positions in `train_ds`.
first_query_hits = train_ds.select(indices[0])
first_query_hits["text"]