In [None]:
# The first thing we need is data, we’ll be concatenating several datasets from this semantic test similarity hub repo.
# We will download each dataset, and extract the relevant text columns into a single list.

import requests
from io import StringIO
import pandas as pd

res = requests.get('https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/sick2014/SICK_train.txt')
# create dataframe
data = pd.read_csv(StringIO(res.text), sep='\t')
data.head()

In [None]:
# we take all samples from both sentence A and B
sentences = data['sentence_A'].tolist()
sentence_b = data['sentence_B'].tolist()
sentences.extend(sentence_b)  # merge them
len(set(sentences))  # together we have ~4.5K unique sentences

In [None]:
# This isn't a particularly large number, so let's pull in a few more similar datasets.
urls = [
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2012/MSRpar.train.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2012/MSRpar.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2012/OnWN.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2013/OnWN.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2014/OnWN.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2014/images.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2015/images.test.tsv'
]
# each of these dataset have the same structure, so we loop through each creating our sentences data
for url in urls:
    res = requests.get(url)
    # extract to dataframe
    data = pd.read_csv(StringIO(res.text), sep='\t', header=None, error_bad_lines=False)
    # add to columns 1 and 2 to sentences list
    sentences.extend(data[1].tolist())
    sentences.extend(data[2].tolist())

# remove duplicates and NaN
sentences = [word for word in list(set(sentences)) if type(word) is str]
len(set(sentences))

In [4]:
# Finally, we build our dense vector representations of each sentence using the sentence-BERT library.

from sentence_transformers import SentenceTransformer
# initialize sentence transformer model
model = SentenceTransformer('bert-base-nli-mean-tokens')
# create sentence embeddings
sentence_embeddings = model.encode(sentences)
sentence_embeddings.shape

(14504, 768)

In [5]:
# Flat index does not require training
import faiss
d = sentence_embeddings.shape[1]
index = faiss.IndexFlatL2(d)
index.is_trained

True

In [6]:
# Once ready, we load our embeddings and query like so:
index.add(sentence_embeddings)
index.ntotal

14504

In [7]:
# Then search given a query xq and number of nearest neigbors to return k.
k = 4
xq = model.encode(["Someone sprints with a football"])

In [8]:
%%time
D, I = index.search(xq, k)  # search
print(I)

[[ 4343 11083 10232  3559]]
CPU times: user 11.9 ms, sys: 95 µs, total: 12 ms
Wall time: 10.6 ms


In [9]:
# Here we're returning indices [ 4343 11083 10232  3559]. The search takes around 12 ms.
answers_index = [4343, 11083, 10232, 3559]
for i in answers_index:
    print(sentences[i])

A group of football players is running in the field
A group of people playing football is running in the field
Two groups of people are playing football
A person playing football is running past an official carrying a football


In [None]:
# Now, if we’d rather extract the numerical vectors from Faiss, we can do that too.

# we have 4 vectors to return (k) - so we initialize a zero array to hold them
import numpy as np
vecs = np.zeros((k, d))
# then iterate through each ID from I and add the reconstructed vector to our zero-array
for i, val in enumerate(I[0].tolist()):
    vecs[i, :] = index.reconstruct(val)
vecs[0][:100]

In [11]:
# IndexFlatL2 does a brute force so it will be slower when vectors are too many. So lets use the Voronoi one (IndexIVFFlat)
nlist = 50  # how many cells
# The quantizer index is used to search inside a cell
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFFlat(quantizer, d, nlist)

In [12]:
# But unlike the flat index, this time we need to train
print(index.is_trained)

index.train(sentence_embeddings)
print(index.is_trained)  # check if index is now trained

index.add(sentence_embeddings)
print(index.ntotal)  # number of embeddings indexed

False
True
14504


In [13]:
%%time
# Let’s search again using the same indexed sentence embeddings and the same query vector xq.
# This time the search time decreased considerably and we get slight
D, I = index.search(xq, k)
print(I)

[[10232  3559  5912   880]]
CPU times: user 830 µs, sys: 237 µs, total: 1.07 ms
Wall time: 509 µs


In [16]:
# Here we're returning indices [10232  3559  5912   880]. The search takes around 1.07 ms. Its not as good as the
# brute force one but still very good.
answers_index = [10232, 3559, 5912, 880]
for i in answers_index:
    print(sentences[i])

Two groups of people are playing football
A person playing football is running past an official carrying a football
A football player kicks the ball.
A football player is running past an official carrying a football


In [17]:
# To get better results, we can increase the nprobe: The number of cells to search.
index.nprobe = 10

In [18]:
%%time
D, I = index.search(xq, k)  # search
print(I)

[[ 4343 11083 10232  3559]]
CPU times: user 1.84 ms, sys: 3.3 ms, total: 5.14 ms
Wall time: 3.9 ms


In [20]:
# The search time increased to 5.14 ms but we get the exact same result as the flat index [ 4343 11083 10232  3559].
# Lets try one final trick (aka product Quantization or PQ). Lets try that

m = 8  # number of centroid IDs in final compressed vectors (d % m == 0)
bits = 8 # number of bits in each centroid

quantizer = faiss.IndexFlatL2(d)  # we keep the same L2 distance flat index
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, bits)

index.is_trained

False

In [21]:
index.train(sentence_embeddings)
index.add(sentence_embeddings)

In [22]:
index.nprobe = 10  # align to previous IndexIVFFlat nprobe value

In [23]:
%%time
D, I = index.search(xq, k)
print(I)

[[ 8401  8562 11083  4343]]
CPU times: user 1.56 ms, sys: 0 ns, total: 1.56 ms
Wall time: 864 µs


In [24]:
# Here we're returning indices [ 8401  8562 11083  4343]. The search only took around 1.56 ms.
answers_index = [8401, 8562, 11083, 4343]
for i in answers_index:
    print(sentences[i])

A group of football players running down the field.
Football players are on the field.
A group of people playing football is running in the field
A group of football players is running in the field
