In [1]:
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModel.from_pretrained("distilbert-base-uncased")

documents = [
    "That restaurant was not as good as the last movie I watched.",
    "I'm selling a used car in good condition",
    "Food was okay, the rest so so",
    "I love cats, but don't really like hyenas",
    "On the road, you must be careful",
]

vectors = [
  # tokenize the document, return it as PyTorch tensors (vectors),
  # and pass it onto the model
  model(**tokenizer(document, return_tensors='pt'))[0].detach().squeeze()
  for document in documents
]

[v.size() for v in vectors]

Downloading:   0%|          | 0.00/442 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/268M [00:00<?, ?B/s]

[torch.Size([15, 768]),
 torch.Size([12, 768]),
 torch.Size([10, 768]),
 torch.Size([15, 768]),
 torch.Size([10, 768])]

In [2]:
import torch

averaged_vectors = [torch.mean(vector, dim=0) for vector in vectors]

[v.size() for v in averaged_vectors]

[torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768])]

In [3]:
def encode(document: str) -> torch.Tensor:
    tokens = tokenizer(document, return_tensors='pt')
    vector = model(**tokens)[0].detach().squeeze()
    return torch.mean(vector, dim=0)

In [14]:
import faiss
import numpy as np

index = faiss.IndexIDMap(faiss.IndexFlatIP(768)) # the size of our vector space
# index all the documents, we need them as numpy arrays first
index.add_with_ids(
    np.array([t.numpy() for t in averaged_vectors]),
    # the IDs will be 0 to len(documents)
    np.array(range(0, len(documents))))

def search(query: str, k=1):
    encoded_query = encode(query).unsqueeze(dim=0).numpy()
    top_k = index.search(encoded_query, k)
    scores = top_k[0][0]
    results = [documents[_id] for _id in top_k[1][0]]
    return list(zip(results, scores))

In [15]:
documents[1]

search(documents[1], k=2)

[('On the road, you must be careful', -3.4028235e+38),
 ('On the road, you must be careful', -3.4028235e+38)]

In [16]:
search("I know how to drive", k=2)

[('On the road, you must be careful', -3.4028235e+38),
 ('On the road, you must be careful', -3.4028235e+38)]