# Lesson Notebook 10 - Embedding-Based Retrieval

In this notebook, we'll explore retrieving and ranking news headlines in response to a query. We'll use an encoder model, similar to the Universal Sentence Encoder, to create a vector for each headline and then a vector for the query. We'll use a library called [SentenceTransformers](https://www.sbert.net/) that provides a large number of pretrained model weight sets on Hugging Face. Sentence Transformers are designed to take a sequence of words, such as a sentence, as input and generate a representative vector, an embedding, as output. Note that Sentence Transformers are only available in PyTorch, but that won't affect our use here thanks to the Hugging Face API.
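SentenceTransformers models turn the encoder's per-token outputs into one fixed-size vector through a pooling step. The sketch below illustrates one common variant, mean pooling with a padding mask followed by L2 normalization, using made-up NumPy arrays in place of a real encoder. It is only an illustration of the general idea: the function names `mean_pool` and `l2_normalize` are ours, and the checkpoint loaded later in this notebook may pool differently.

```python
import numpy as np

def mean_pool(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average token vectors into one sentence vector per row, ignoring padding.

    token_embeddings: (batch, seq_len, dim) hypothetical encoder outputs
    attention_mask:   (batch, seq_len), 1 for real tokens and 0 for padding
    """
    mask = attention_mask[:, :, None].astype(token_embeddings.dtype)  # (batch, seq_len, 1)
    summed = (token_embeddings * mask).sum(axis=1)                    # (batch, dim)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)                    # (batch, 1), avoids divide-by-zero
    return summed / counts

def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Scale each row to unit length so a dot product equals cosine similarity."""
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    return x / np.clip(norms, 1e-12, None)

# Toy batch: 2 "sentences", 3 token positions, 4-dim token vectors (made-up numbers).
tokens = np.array([
    [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [9.0, 9.0, 9.0, 9.0]],  # last position is padding
    [[0.0, 0.0, 2.0, 0.0], [0.0, 0.0, 0.0, 2.0], [0.0, 0.0, 2.0, 2.0]],
])
mask = np.array([[1, 1, 0], [1, 1, 1]])

sentence_vectors = l2_normalize(mean_pool(tokens, mask))
print(sentence_vectors.shape)  # (2, 4)
```

The mask matters: without it, the padding vector `[9, 9, 9, 9]` would dominate the first sentence's embedding.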

First, we'll generate vectors for our headlines and store them. Then we'll generate an embedding for our query and run a Nearest Neighbors search over the full set of news headlines. Finally, we'll cluster the news headline embeddings first and apply Nearest Neighbors only to the headlines in the top k clusters whose centroids are most similar to the query embedding.

If we were building a system that needed to scale, we would use something like the ScaNN library to hold our embeddings and perform our searches.

<a id = 'returnToTop'></a>

## Notebook Contents

  * 1. [Setup](#setup)
  * 2. [Data Preparation](#dataPrep)
  * 3. [Encode Embeddings](#encodeData)
  * 4. [Query and Retrieval](#queryRet)
  * 5. [Retrieval via Clusters](#clusterRet)
  * 6. [Answers](#answers)      









[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/datasci-w266/2024-spring-main/blob/master/materials/lesson_notebooks/lesson_10_embedding_based_retrieval.ipynb)

[Return to Top](#returnToTop)  
<a id = 'setup'></a>

### 1. Setup

In [1]:
!pip install -q -U sentence-transformers
!pip install -q  -U datasets

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.0/86.0 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.7/7.7 MB[0m [31m66.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m53.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m302.0/302.0 kB[0m [31m26.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.8/3.8 MB[0m [31m49.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m38.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m295.0/295.0 kB[0m [31m26.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for sentence-transformers (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━

In [2]:
import os
import time
import numpy as np
from datasets import load_dataset

from scipy.spatial.distance import cosine
from sklearn.neighbors import NearestNeighbors
from sklearn.cluster import KMeans
from sentence_transformers import SentenceTransformer

[Return to Top](#returnToTop)  
<a id = 'dataPrep'></a>

### 2. Data Preparation

For our data we'll use the test portion of the XSum summarization dataset. The goal of XSum is to generate a one-line summary of the input article. We'll grab the 'summary' field, since these one-sentence summaries make an excellent set of "sentences" (our stand-in for news headlines) for the retrieval experiment. It takes about a minute to download and process the data and get the test records.

In [3]:
dataset = load_dataset('xsum', split='test')


In [4]:
len(dataset)

11334

In [5]:
dataset[0]

{'document': 'Prison Link Cymru had 1,099 referrals in 2015-16 and said some ex-offenders were living rough for up to a year before finding suitable accommodation.\nWorkers at the charity claim investment in housing would be cheaper than jailing homeless repeat offenders.\nThe Welsh Government said more people than ever were getting help to address housing problems.\nChanges to the Housing Act in Wales, introduced in 2015, removed the right for prison leavers to be given priority for accommodation.\nPrison Link Cymru, which helps people find accommodation after their release, said things were generally good for women because issues such as children or domestic violence were now considered.\nHowever, the same could not be said for men, the charity said, because issues which often affect them, such as post traumatic stress disorder or drug dependency, were often viewed as less of a priority.\nAndrew Stevens, who works in Welsh prisons trying to secure housing for prison leavers, said the

[Return to Top](#returnToTop)  
<a id = 'encodeData'></a>

### 3. Encode Embeddings

We'll load a SentenceTransformer with a smaller model so that it runs quickly in the live session. You can experiment with other models to see the trade-off between size, processing time, and quality. For example, you could load `'sentence-transformers/all-roberta-large-v1'` to get the improvements that come with a large RoBERTa model. You can see [a full listing of models](https://huggingface.co/models?library=sentence-transformers&sort=downloads) on Hugging Face. We'll use the checkpoint based on [this paper](https://arxiv.org/pdf/2006.03659.pdf).
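Model choice also changes the embedding dimension, which drives how much memory the stored embeddings take. A back-of-the-envelope sketch, assuming float32 storage and assuming 768 dimensions for LaBSE (BERT-base hidden size) and 1024 for `all-roberta-large-v1` (RoBERTa-large hidden size); you can confirm the actual value for a loaded model with `encoder_model.get_sentence_embedding_dimension()`.

```python
BYTES_PER_FLOAT32 = 4

def embedding_matrix_bytes(num_texts: int, dim: int) -> int:
    """Bytes needed to hold num_texts embeddings of size dim as a dense float32 matrix."""
    return num_texts * dim * BYTES_PER_FLOAT32

# Assumed output dimensions (see lead-in); 7,500 matches the subset encoded below.
for name, dim in [("LaBSE", 768), ("all-roberta-large-v1", 1024)]:
    for n in (7_500, 1_000_000):
        mb = embedding_matrix_bytes(n, dim) / 1e6
        print(f"{name:>22}  n={n:>9,}  ->  {mb:,.1f} MB")
```

At 7,500 headlines either choice is tiny (about 23 MB vs 31 MB), but at a million documents the difference is roughly 3 GB vs 4 GB, before any index overhead.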

In [7]:
encoder_model = SentenceTransformer('sentence-transformers/LaBSE')


In [8]:
# Just encoding a subset so it doesn't take too long during the live session (24 seconds)

news_headlines = [x['summary'] for x in dataset]
news_embeddings = encoder_model.encode(news_headlines[:7500])

[Return to Top](#returnToTop)  
<a id = 'queryRet'></a>

### 4. Query and Retrieval

First, we'll create a query and generate an embedding to represent it. Then we'll compare that query embedding against *all* of the headline embeddings to find the 10 nearest neighbors. This brute-force approach does not scale: comparing the query embedding with every news headline embedding for each new query costs time proportional to the size of the collection.
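Stripped of the scikit-learn wrapper, brute-force retrieval is one dot product per document followed by a top-k selection. The NumPy sketch below uses random vectors as stand-ins for headline embeddings, and `brute_force_topk` is a hypothetical helper, not part of any library. It also checks a useful identity: scikit-learn's `NearestNeighbors` defaults to Euclidean distance (`metric='minkowski'`, `p=2`), and for unit-length vectors the squared Euclidean distance equals `2 - 2 * cosine_similarity`, so Euclidean and cosine produce the same ranking.

```python
import numpy as np

def brute_force_topk(query: np.ndarray, docs: np.ndarray, k: int = 10):
    """Exact top-k by cosine similarity: score the query against every document.

    query: (dim,), docs: (n, dim). Returns (indices, similarities), best first.
    """
    q = query / np.linalg.norm(query)
    d = docs / np.linalg.norm(docs, axis=1, keepdims=True)
    sims = d @ q                               # one dot product per document: O(n * dim)
    k = min(k, len(sims))
    top = np.argpartition(-sims, k - 1)[:k]    # the k best, in arbitrary order, in O(n)
    top = top[np.argsort(-sims[top])]          # sort only those k, best first
    return top, sims[top]

rng = np.random.default_rng(0)
docs = rng.normal(size=(1000, 64))   # random stand-ins for headline embeddings
query = rng.normal(size=64)

idx, sims = brute_force_topk(query, docs, k=10)

# For unit-length vectors: ||a - b||^2 = 2 - 2 * cos(a, b).
unit_docs = docs / np.linalg.norm(docs, axis=1, keepdims=True)
unit_q = query / np.linalg.norm(query)
euclid = np.linalg.norm(unit_docs - unit_q, axis=1)
assert np.allclose(euclid[idx] ** 2, 2 - 2 * sims)
print(idx[:3], np.round(sims[:3], 3))
```

If the headline embeddings in this notebook are unit length (check with `np.linalg.norm(news_embeddings, axis=1)`), a printed Euclidean distance of about 0.92 would correspond to a cosine similarity of roughly 1 - 0.92² / 2 ≈ 0.58.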

In [9]:
# Let's try a query for some news we might be looking for

query = 'Tiger Woods did not make the cut at golf tournament.'
query_embedding = encoder_model.encode([query])

In [10]:
# We'll start by loading all of the news embeddings into a Nearest Neighbors model

knn_model = NearestNeighbors(n_neighbors=10)
knn_model.fit(news_embeddings)

In [11]:
# We'll keep track of the time it takes to find the top 10 nearest headlines

start = time.time()
dists, topk_idx = knn_model.kneighbors(query_embedding)
for d, i in zip(dists[0], topk_idx[0]):
    print(d, news_headlines[i])

print('\nTime:', time.time() - start)

# (We're using a small number of headlines so it's fast for the live session,
# but it'll still go even faster if we narrow the likely candidates first.)

0.9218927478698277 Tiger Woods missed the cut at the Farmers Insurance Open, as England's Justin Rose maintained a one-shot lead.
1.0175832210127616 English rider Guy Martin will not compete in the Ulster Grand Prix at Dundrod for the third year in a row.
1.0483298938636172 Greg Dyke will not seek re-election as Football Association chairman when his term ends in June.
1.0495320094635274 Tiger Woods admits he has concerns over the physical challenge of stepping up his return from long-term injury.
1.0528629486325418 Six-time champion Steve Davis failed to reach the World Championship as he lost 10-4 to Fergal O'Brien in the first round of qualifying in Sheffield.
1.0567675389137763 Stephen Maguire said he was "embarrassed" at not being able to motivate himself for the World Championship at the Crucible.
1.0629106737718323 A golfer has suffered leg injuries after being bitten by a crocodile on an Australian golf course.
1.0718270605282665 You can't win the Davis Cup on your own and reac

Because we only have 7,500 headlines, we can get the 10 closest headlines in about 0.06 seconds.
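The timing cells above use `time.time()` around a single run that also includes the `print` calls, so numbers at this scale are noisy. A sketch of a steadier measurement with the standard-library `timeit` module, taking the best of several runs; it times bare NumPy dot-product scoring on random float32 stand-ins shaped like our 7,500 x 768 subset (768 is an assumed dimension), not the notebook's actual `knn_model`.

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)
# Random float32 stand-ins with the same shape as our encoded subset (768 dims assumed).
docs = rng.normal(size=(7500, 768)).astype(np.float32)
query = rng.normal(size=768).astype(np.float32)

def top10_full():
    # Dot-product scores against every document, then the indices of the 10 highest.
    return np.argsort(docs @ query)[-10:]

def top10_subset(n_candidates=319):
    # Same scoring restricted to a candidate subset (319 matches the clustered cell's count).
    return np.argsort(docs[:n_candidates] @ query)[-10:]

# number=1 runs the function once per measurement; repeat=5 takes five measurements.
# The minimum is the run least disturbed by other activity on the machine.
full = min(timeit.repeat(top10_full, number=1, repeat=5))
subset = min(timeit.repeat(top10_subset, number=1, repeat=5))
print(f"all 7,500: {full * 1e3:.2f} ms | 319 candidates: {subset * 1e3:.2f} ms")
```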

[Return to Top](#returnToTop)  
<a id = 'clusterRet'></a>

### 5. Retrieval via Clusters

If we cluster the document embeddings ahead of time, we can speed up and scale the retrieval process. At query time, we first find the clusters whose centroids are "close" to our query. Then we examine (and score) only the document embeddings within those few clusters (the top two in the cells below), rather than every document in the collection.
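The same idea, often called an inverted-file (IVF) index, can be sketched end to end in NumPy. This is a minimal illustration under our own assumptions: it swaps scikit-learn's `KMeans` for a hand-written Lloyd's loop so the block is self-contained, uses Euclidean distance for both the centroid step and the final scoring (the notebook uses cosine for the centroid step), and the names `kmeans`, `cluster_pruned_search`, and `n_probe` are hypothetical.

```python
import numpy as np

def assign(x: np.ndarray, centroids: np.ndarray) -> np.ndarray:
    """Index of the nearest centroid (squared Euclidean) for every row of x."""
    d2 = ((x[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)  # (n, n_clusters)
    return d2.argmin(axis=1)

def kmeans(x: np.ndarray, n_clusters: int, n_iter: int = 20, seed: int = 0):
    """Plain Lloyd's k-means with random initial centroids; returns (centroids, labels)."""
    rng = np.random.default_rng(seed)
    centroids = x[rng.choice(len(x), size=n_clusters, replace=False)].copy()
    for _ in range(n_iter):
        labels = assign(x, centroids)
        for c in range(n_clusters):
            members = x[labels == c]
            if len(members):  # keep the old centroid if a cluster ends up empty
                centroids[c] = members.mean(axis=0)
    return centroids, assign(x, centroids)  # labels consistent with the final centroids

def cluster_pruned_search(query, x, centroids, labels, n_probe=2, k=10):
    """Score only the documents in the n_probe clusters whose centroids are nearest the query."""
    probe = np.argsort(((centroids - query) ** 2).sum(axis=1))[:n_probe]
    candidates = np.flatnonzero(np.isin(labels, probe))
    dists = np.linalg.norm(x[candidates] - query, axis=1)
    order = np.argsort(dists)[:k]
    return candidates[order], dists[order], len(candidates)

rng = np.random.default_rng(1)
x = rng.normal(size=(2000, 16))   # random stand-ins for document embeddings
query = rng.normal(size=16)
centroids, labels = kmeans(x, n_clusters=20)        # done once, offline
ids, dists, n_scored = cluster_pruned_search(query, x, centroids, labels, n_probe=2, k=10)
print(f"scored {n_scored} of {len(x)} documents")
```

Setting `n_probe` equal to the number of clusters degenerates to exact brute-force search; smaller values score fewer documents but may miss true neighbors that sit in other clusters.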

In [12]:
# Now let's try clustering the news headlines beforehand. This takes time,
# but we only need to do it once, then re-use it for different queries.

cluster_model = KMeans(n_clusters=50)
news_clusters = cluster_model.fit_predict(news_embeddings)



In [13]:
cluster_news_ids = {i: [] for i in range(50)}
for i, c in enumerate(news_clusters):
    cluster_news_ids[c].append(i)

In [14]:
# Compute the distance from the query embedding to each cluster centroid

query_cluster_dists = [cosine(query_embedding[0], cluster_model.cluster_centers_[c])
                       for c in range(50)]

In [15]:
# Get the top k nearest clusters and retrieve their document ids
# (You can try different numbers of top clusters, to see the trade-off between
# speed and recall of all the best articles we found above.)

top_clusters = np.argsort(query_cluster_dists)[:2]
candidate_news_ids = [i for c in top_clusters for i in cluster_news_ids[c]]
len(candidate_news_ids)

319

In [16]:
# Now use Nearest Neighbors only on the top cluster candidates

candidate_news_embeds = [news_embeddings[i] for i in candidate_news_ids]

knn_model = NearestNeighbors(n_neighbors=10)
knn_model.fit(candidate_news_embeds)

start = time.time()
dists, topk_idx = knn_model.kneighbors(query_embedding)
for d, i in zip(dists[0], topk_idx[0]):
    orig_i = candidate_news_ids[i]
    print(d, news_headlines[orig_i])

print('\nTime:', time.time() - start)

1.0495320094635274 Tiger Woods admits he has concerns over the physical challenge of stepping up his return from long-term injury.
1.0528629486325418 Six-time champion Steve Davis failed to reach the World Championship as he lost 10-4 to Fergal O'Brien in the first round of qualifying in Sheffield.
1.0897285646665726 Former world number one Tiger Woods says he is getting "professional help" to manage medication for pain and sleep loss as he tries to return to fitness.
1.0906746817479285 Wales wing George North's decision not to return to a Welsh region is a disappointment, says Rugby Wales chief executive Mark Davies.
1.1027796628910824 Dutchman Dylan Groenewegen claimed the opening stage victory in this year's Tour de Yorkshire.
1.1104867227199433 Teenage jockey David Mullins said he has "never had a feeling like it" after winning the Grand National on Rule The World at Aintree.
1.1106215444842729 Wales rugby great Gareth Edwards has been knighted by the Duke of Cambridge in recogniti

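The clustered approach is faster, taking only about one hundredth of a second, but its results are not equally good. The best match from the full search, "Tiger Woods missed the cut at the Farmers Insurance Open..." (distance 0.92), does not appear, because that headline fell in a cluster outside the top two; probing more clusters would recover some of that recall at the cost of speed. The time savings will be meaningful when we have millions or billions of records that need to be searched.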
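One way to quantify what the clustered search gives up is recall@k: the fraction of the exact top-k results that the approximate search also returns. A minimal sketch with hypothetical document ids; `recall_at_k` is our own helper, not a library function.

```python
def recall_at_k(exact_ids, approx_ids, k=10):
    """Fraction of the exact top-k ids that also appear in the approximate top-k."""
    exact_top = set(list(exact_ids)[:k])
    if not exact_top:
        return 0.0
    return len(exact_top & set(list(approx_ids)[:k])) / len(exact_top)

# Hypothetical ids for illustration.
exact = [4, 8, 15, 16, 23]
approx = [8, 15, 42, 16, 99]
print(recall_at_k(exact, approx, k=5))  # 0.6
```

To apply this in the notebook, you would save the first search's `topk_idx[0]` under another name before the clustered cell overwrites it, then compare it with `[candidate_news_ids[i] for i in topk_idx[0]]` from the clustered search, trying different numbers of top clusters.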

In practice, instead of hand-rolling this clustering approach, you would want to use a library such as [ScaNN](https://github.com/google-research/google-research/tree/master/scann).