# Recommendation using embeddings and nearest neighbor search

Recommendations are widespread across the web.

- 'Bought that item? Try these similar items.'
- 'Enjoy that book? Try these similar titles.'
- 'Not the help page you were looking for? Try these similar pages.'

This notebook demonstrates how to use embeddings to find similar items to recommend. In particular, we use [AG's corpus of news articles](http://groups.di.unipi.it/~gulli/AG_corpus_of_news_articles.html) as our dataset.

Our model will answer the question: given an article, what other articles are most similar to it?

### 1. Imports

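First, let's import the packages and functions we'll need later. If you don't have these, you'll need to install them. You can install them via your terminal by running `pip install {package_name}`, e.g. `pip install pandas`. Note that the helper functions imported from `openai.embeddings_utils` ship only with versions of the `openai` package earlier than 1.0.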

In [1]:
# imports
import pandas as pd
import pickle

from openai.embeddings_utils import (
    get_embedding,
    distances_from_embeddings,
    tsne_components_from_embeddings,
    chart_from_components,
    chart_from_components_3D,
    indices_of_nearest_neighbors_from_distances,
)

# constants
EMBEDDING_MODEL = "text-embedding-ada-002"

### 2. Load data

Next, let's load the AG news data and see what it looks like.
This is also where you would change the code to load a different dataset.
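When you swap in another CSV, the rest of the notebook only needs a list of strings to embed (here, `df["description"].tolist()`). The sketch below assumes a hypothetical helper `load_texts` and uses an illustrative in-memory CSV. It checks that the chosen column exists and drops empty rows, which cannot be meaningfully embedded.

```python
import io

import pandas as pd


def load_texts(csv_source, text_column: str) -> list[str]:
    """Hypothetical helper: load a CSV and return the non-empty values of one text column."""
    df = pd.read_csv(csv_source)
    if text_column not in df.columns:
        raise KeyError(f"column {text_column!r} not found; available: {list(df.columns)}")
    # drop missing values so every remaining item can be embedded
    return df[text_column].dropna().astype(str).tolist()


# an in-memory CSV stands in for a real file; row B has an empty description
sample_csv = io.StringIO("title,description\nA,first text\nB,\nC,third text\n")
texts = load_texts(sample_csv, "description")
print(texts)  # ['first text', 'third text']
```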

In [126]:
# load data (full dataset available at http://groups.di.unipi.it/~gulli/AG_corpus_of_news_articles.html)
dataset_path = "data/AG_news_samples.csv"

# to use a different dataset (for example, a CSV of bank transactions), point dataset_path at that file instead
df = pd.read_csv(dataset_path)

# print dataframe
n_examples = 5
df.head(n_examples)


| | title | description | label_int | label |
|---|---|---|---|---|
| 0 | World Briefings | BRITAIN: BLAIR WARNS OF CLIMATE THREAT Prime M... | 1 | World |
| 1 | Nvidia Puts a Firewall on a Motherboard (PC Wo... | PC World - Upcoming chip set will include buil... | 4 | Sci/Tech |
| 2 | Olympic joy in Greek, Chinese press | Newspapers in Greece reflect a mixture of exhi... | 2 | Sports |
| 3 | U2 Can iPod with Pictures | SAN JOSE, Calif. -- Apple Computer (Quote, Cha... | 4 | Sci/Tech |
| 4 | The Dream Factory | Any product, any shape, any size -- manufactur... | 4 | Sci/Tech |


Let's take a look at those same examples, but not truncated by ellipses.

In [127]:
# inspect the DataFrame, then print the title, description, and label of each example
print(df.index) # the RangeIndex shows how many rows the DataFrame has
print(df[:5])   # the first 5 rows with their column headings (long fields are truncated)
# print the full contents of each field, omitting the label_int field
for idx, row in df.head(n_examples).iterrows():
    print("")
    print(f"Title: {row['title']}")
    print(f"Description: {row['description']}")
    print(f"Label: {row['label']}")

RangeIndex(start=0, stop=2000, step=1)
                                               title  \
0                                    World Briefings   
1  Nvidia Puts a Firewall on a Motherboard (PC Wo...   
2                Olympic joy in Greek, Chinese press   
3                          U2 Can iPod with Pictures   
4                                  The Dream Factory   

                                         description  label_int     label  
0  BRITAIN: BLAIR WARNS OF CLIMATE THREAT Prime M...          1     World  
1  PC World - Upcoming chip set will include buil...          4  Sci/Tech  
2  Newspapers in Greece reflect a mixture of exhi...          2    Sports  
3  SAN JOSE, Calif. -- Apple Computer (Quote, Cha...          4  Sci/Tech  
4  Any product, any shape, any size -- manufactur...          4  Sci/Tech  

Title: World Briefings
Description: BRITAIN: BLAIR WARNS OF CLIMATE THREAT Prime Minister Tony Blair urged the international community to consider global warming a dir

### 3. Build cache to save embeddings

Before getting embeddings for these articles, let's set up a cache to save the embeddings we generate. In general, it's a good idea to save your embeddings so you can reuse them later. If you don't, you'll pay to compute them again every time you need them.

The cache is a dictionary that maps tuples of `(text, model)` to an embedding, which is a list of floats. The cache is saved as a Python pickle file.
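The same caching pattern can be exercised without an API key. This sketch swaps the real `get_embedding` call for a hypothetical `fake_embedding` function, writes the pickle to a temporary directory, and uses a made-up model name, `"example-model"`. It shows that a second lookup for the same `(text, model)` key is served from the reloaded pickle rather than recomputed.

```python
import os
import pickle
import tempfile

api_calls = {"count": 0}  # counts how often the (fake) API is hit


def fake_embedding(text: str) -> list[float]:
    """Hypothetical stand-in for an embeddings API call: a deterministic toy vector."""
    api_calls["count"] += 1
    return [float(len(text)), float(sum(map(ord, text)) % 97)]


def cached_embedding(text: str, model: str, cache: dict, cache_path: str) -> list[float]:
    """Return the embedding for (text, model), computing and pickling it only on a cache miss."""
    key = (text, model)
    if key not in cache:  # dict membership is a hash lookup, not a scan of a list
        cache[key] = fake_embedding(text)
        with open(cache_path, "wb") as f:
            pickle.dump(cache, f)
    return cache[key]


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "embedding_cache.pkl")
    first = cached_embedding("hello", "example-model", {}, path)
    # a later session reloads the pickle instead of calling the API again
    with open(path, "rb") as f:
        reloaded = pickle.load(f)
    second = cached_embedding("hello", "example-model", reloaded, path)

print(first == second, api_calls["count"], len(reloaded))  # True 1 1
```

Keying on the model as well as the text means that switching models never returns an embedding computed by a different model.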

In [128]:
# establish a cache of embeddings to avoid recomputing
# cache is a dict of tuples (text, model) -> embedding, saved as a pickle file
# (this cell loads the cache and defines embedding_from_string; the API is only called when that function runs)

# set path to embedding cache
embedding_cache_path = "data/recommendations_embeddings_cache.pkl"

# load the cache if it exists, and save a copy to disk
try:
    embedding_cache = pd.read_pickle(embedding_cache_path)
except FileNotFoundError:
    embedding_cache = {}
with open(embedding_cache_path, "wb") as embedding_cache_file:
    pickle.dump(embedding_cache, embedding_cache_file)

# retrieve an embedding from the cache if present; otherwise request it via the API (which costs tokens) and cache it
def embedding_from_string(
    string: str,
    model: str = EMBEDDING_MODEL,
    embedding_cache=embedding_cache
) -> list:
    """
    Return embedding of given string, using a cache to avoid recomputing.
    embedding_cache is a dict keyed by (string, model); membership tests on a dict
    are hash lookups, so they stay fast even with thousands of cached entries.
    """
    if (string, model) not in embedding_cache:
        embedding_cache[(string, model)] = get_embedding(string, model)
        with open(embedding_cache_path, "wb") as embedding_cache_file:
            pickle.dump(embedding_cache, embedding_cache_file)
    return embedding_cache[(string, model)]

Let's check that it works by getting an embedding. On a cache miss this calls the OpenAI API, so it needs a valid API key.

In [129]:
# as an example, take the first description from the dataset
example_string = df["description"].values[0]
print(f"\nExample string: {example_string}")

# print the first 10 dimensions of the embedding
example_embedding = embedding_from_string(example_string)
print(f"\nExample embedding: {example_embedding[:10]}...")
print(f"\nLength of embedding: {len(example_embedding)} ...")


Example string: BRITAIN: BLAIR WARNS OF CLIMATE THREAT Prime Minister Tony Blair urged the international community to consider global warming a dire threat and agree on a plan of action to curb the  quot;alarming quot; growth of greenhouse gases.



### 4. Recommend similar articles based on embeddings

To find similar articles, let's follow a three-step plan (the function defined in the cell below carries it out when it is called in the next section):
1. Get the embeddings of all the article descriptions
2. Calculate the distance between the source article and all other articles
3. Print out the articles closest to the source article
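The three steps can be sketched with NumPy on toy vectors, assuming cosine distance as in the function below. This is an illustration, not the `embeddings_utils` implementation: `recommend` is a hypothetical helper, and it excludes the source by index, whereas the notebook skips any string identical to the source.

```python
import numpy as np


def recommend(embeddings: np.ndarray, source_index: int, k: int) -> list[int]:
    """Hypothetical helper: indices of the k items closest to the source by cosine distance."""
    # step 1 is assumed done: one embedding per row
    # step 2: cosine distance = 1 - cosine similarity between the source and every row
    normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    distances = 1.0 - normed @ normed[source_index]
    # step 3: sort by increasing distance, drop the source itself, keep the first k
    order = np.argsort(distances)
    return [int(i) for i in order if i != source_index][:k]


toy = np.array([
    [1.0, 0.0],   # 0: source
    [0.9, 0.1],   # 1: almost the same direction
    [0.0, 1.0],   # 2: orthogonal
    [0.7, 0.7],   # 3: in between
])
print(recommend(toy, source_index=0, k=2))  # [1, 3]
```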

In [None]:
def print_recommendations_from_strings(
    strings: list[str],
    index_of_source_string: int,
    k_nearest_neighbors: int = 1,
    model=EMBEDDING_MODEL,
) -> list[int]:
    """Print out the k nearest neighbors of a given string."""
    # get embeddings for all strings
    print(f"There are {len(strings):5.0f} strings in the database, all with previously-calculated embeddings.")
    # each embedding is read from the pickled cache if present, otherwise requested from the API
    embeddings = [embedding_from_string(
        string, model=model) for string in strings]
    # get the embedding of the source string (e.g. the Tony Blair description at index 0)
    query_embedding = embeddings[index_of_source_string]
    # get distances between the source embedding and all embeddings
    # (functions from embeddings_utils.py)
    # "cosine" distance is 1 - cosine similarity, i.e. 1 minus the dot product of the normalised vectors:
    # 0 means the same direction, 1 orthogonal, 2 opposite
    # sanity check on small vectors: identical, similar, and nearly opposite to [1, 2, 3]
    distances = distances_from_embeddings([1.0,2.0,3.0],[[1.0,2.0,3.0],[4.0,5.0,6.0],[-1.0,-2.0,-6.0]],distance_metric="cosine")
    print(distances)
    distances = distances_from_embeddings(
        query_embedding, embeddings, distance_metric="cosine")
    # the first 10 raw distances in dataset order (smaller means closer)
    print(distances[:10])
    # get indices of nearest neighbors (function from embeddings_utils.py)
    indices_of_nearest_neighbors = indices_of_nearest_neighbors_from_distances(
        distances)
    print("Returned in increasing order of distance but numbered by their position in the original string table \n",indices_of_nearest_neighbors[:10],'\n')
    # print the 20 closest items; each entry of indices_of_nearest_neighbors is an index into `strings`,
    # so the distance is looked up indirectly (the second number printed is that index, not a rank)
    for i in range(min(20, len(indices_of_nearest_neighbors))):
        posn = indices_of_nearest_neighbors[i]
        print(f"Rank {i:4.0f}; Distance coefficient: {distances[posn]:.3f} & {indices_of_nearest_neighbors[i]:4.0f} of nearest neighbours")
    # print out source string
    query_string = strings[index_of_source_string]
    print(f"Source string: {query_string}")
    # print out its k nearest neighbors
    k_counter = 0
    for i in indices_of_nearest_neighbors:
        # skip any strings that are identical matches to the starting string
        if query_string == strings[i]:
            continue
        # stop after printing out k articles
        if k_counter >= k_nearest_neighbors:
            break
        k_counter += 1

        # print out the similar strings and their distances
        print(
            f"""
        --- Recommendation #{k_counter} (nearest neighbour {k_counter} of {k_nearest_neighbors}) ---
        Number: {i:2.0f}
        String: {strings[i]}
        Distance: {distances[i]:0.3f}"""
        )

    return indices_of_nearest_neighbors
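
The sanity check inside the function compares `[1, 2, 3]` against three vectors. A NumPy sketch of cosine distance (1 minus cosine similarity) reproduces the printed values `[0, 0.0254, 1.96]`, showing that the metric measures the angle between vectors rather than a raw dot product. `cosine_distance` is a hypothetical helper written for this check.

```python
import numpy as np


def cosine_distance(a, b) -> float:
    """1 - cosine similarity: 0 for the same direction, 1 for orthogonal, 2 for opposite."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    return float(1.0 - a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


query = [1.0, 2.0, 3.0]
candidates = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [-1.0, -2.0, -6.0]]
print([round(cosine_distance(query, c), 4) for c in candidates])
# approximately [0, 0.0254, 1.96]
```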


### 5. Example recommendations

Let's look for articles similar to the first one, which was about Tony Blair.

In [None]:
article_descriptions = df["description"].tolist()

tony_blair_articles = print_recommendations_from_strings(
    strings=article_descriptions,  # base similarity on the article descriptions only (titles are not embedded)
    index_of_source_string=0,  # let's look at articles similar to the first one about Tony Blair
    k_nearest_neighbors=5,  # look at the 5 most similar articles (this keyword argument overrides the default of 1)
)

There are  2000 strings in the database, all with previously-calculated embeddings.
[0, 0.0253681538029239, 1.9600014517991347]
[0, 0.25761625996350557, 0.27101808868914845, 0.27640579866482584, 0.2716609724435485, 0.26769326170958374, 0.23804153608317613, 0.26782516357228925, 0.2667089596535165, 0.2552016889375752]
Returned in increasing order of distance but numbered by their position in the original string table 
 [   0  991 1044 1396  766 1064  131  434 1618  852] 

Rank    0; Distance coefficient: 0.000 &    0 of nearest neighbours
Rank    1; Distance coefficient: 0.153 &  991 of nearest neighbours
Rank    2; Distance coefficient: 0.160 & 1044 of nearest neighbours
Rank    3; Distance coefficient: 0.160 & 1396 of nearest neighbours
Rank    4; Distance coefficient: 0.171 &  766 of nearest neighbours
Rank    5; Distance coefficient: 0.173 & 1064 of nearest neighbours
Rank    6; Distance coefficient: 0.175 &  131 of nearest neighbours
Rank    7; Distance coefficient: 0.178 &  434 of 

Pretty good! 4 of the 5 recommendations explicitly mention Tony Blair and the fifth is an article from London about climate change, topics that might often be associated with Tony Blair.

Let's see how our recommender does on the second example article about NVIDIA's new chipset with more security.

In [69]:
chipset_security_articles = print_recommendations_from_strings(
    strings=article_descriptions,  # let's base similarity on the article description
    index_of_source_string=1,  # let's look at articles similar to the second one about a more secure chipset
    k_nearest_neighbors=5,  # let's look at the 5 most similar articles
)

There are  2000 strings in the database, all with previously-calculated embeddings.
[0.25761625996350557, 0, 0.28327987857317705, 0.22083238324438348, 0.20332350896619844, 0.28592724847340667, 0.279130550574168, 0.2829386152594774, 0.3050574747074891, 0.175832529454513]
Returned in increasing order of distance but numbered by their position in the original string table 
 [   1  158 1549  436 1088 1180   33  485    9  458] 

Rank    0; Distance coefficient: 0.000 &    1 of nearest neighbours
Rank    1; Distance coefficient: 0.112 &  158 of nearest neighbours
Rank    2; Distance coefficient: 0.145 & 1549 of nearest neighbours
Rank    3; Distance coefficient: 0.153 &  436 of nearest neighbours
Rank    4; Distance coefficient: 0.157 & 1088 of nearest neighbours
Rank    5; Distance coefficient: 0.168 & 1180 of nearest neighbours
Rank    6; Distance coefficient: 0.172 &   33 of nearest neighbours
Rank    7; Distance coefficient: 0.174 &  485 of nearest neighbours
Rank    8; Distance coeffici

From the printed distances, you can see that the #1 recommendation is much closer than all the others (0.11 vs 0.14+). And the #1 recommendation looks very similar to the starting article - it's another article from PC World about increasing computer security. Pretty good! 

## Appendix: Using embeddings in more sophisticated recommenders

A more sophisticated way to build a recommender system is to train a machine learning model that takes in tens or hundreds of signals, such as item popularity or user click data. Even in this system, embeddings can be a very useful signal into the recommender, especially for items that are being 'cold started' with no user data yet (e.g., a brand new product added to the catalog without any clicks yet).
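As an illustration of embeddings as one signal among many, a recommender might compute, for a brand-new item, its highest cosine similarity to items a user has already engaged with, and feed that number to the ranking model alongside other features. A minimal sketch with toy 3-dimensional vectors; the helper `max_similarity_to_history` and the choice of maximum (rather than, say, mean) similarity are assumptions.

```python
import numpy as np


def max_similarity_to_history(new_item: np.ndarray, history: np.ndarray) -> float:
    """Hypothetical feature: highest cosine similarity between a new item and a user's past items."""
    new_item = new_item / np.linalg.norm(new_item)
    history = history / np.linalg.norm(history, axis=1, keepdims=True)
    return float(np.max(history @ new_item))


# toy vectors standing in for embeddings of items the user clicked
user_history = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
brand_new_item = np.array([0.6, 0.8, 0.0])  # no clicks yet: a cold-start item
feature = max_similarity_to_history(brand_new_item, user_history)
print(round(feature, 2))  # 0.8
```

Because the feature needs only the item's text embedding, it is available the moment the item is added to the catalog.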

## Appendix: Using embeddings to visualize similar articles

To get a sense of what our nearest neighbor recommender is doing, let's visualize the article embeddings. Although we can't plot the 1536 dimensions of each `text-embedding-ada-002` embedding vector, we can use techniques like [t-SNE](https://en.wikipedia.org/wiki/T-distributed_stochastic_neighbor_embedding) or [PCA](https://en.wikipedia.org/wiki/Principal_component_analysis) to compress the embeddings down into 2 or 3 dimensions, which we can chart.
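t-SNE (used below through `tsne_components_from_embeddings`) relies on scikit-learn, but PCA, the other technique mentioned, can be sketched in a few lines of NumPy via the singular value decomposition. Unlike t-SNE, PCA is a deterministic linear projection. `pca_components` is a hypothetical helper, and random vectors stand in for real embeddings.

```python
import numpy as np


def pca_components(embeddings: np.ndarray, n_components: int = 2) -> np.ndarray:
    """Hypothetical helper: project embeddings onto their top principal components (PCA via SVD)."""
    centered = embeddings - embeddings.mean(axis=0)
    # rows of vt are the principal directions, ordered by decreasing explained variance
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:n_components].T


rng = np.random.default_rng(0)
fake_embeddings = rng.normal(size=(100, 1536))  # random stand-ins for 1536-d embeddings
points_2d = pca_components(fake_embeddings, 2)
points_3d = pca_components(fake_embeddings, 3)
print(points_2d.shape, points_3d.shape)  # (100, 2) (100, 3)
```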

Before visualizing the nearest neighbors, let's visualize all of the article descriptions using t-SNE. Note that t-SNE is not deterministic, meaning that results may vary from run to run.

In [70]:
# get embeddings for all article descriptions
embeddings = [embedding_from_string(string) for string in article_descriptions]
# compress the 1536-dimensional embeddings into 2 dimensions using t-SNE
tsne_components = tsne_components_from_embeddings(embeddings, 2)
# get the article labels for coloring the chart
labels = df["label"].tolist()

chart_from_components(
    components=tsne_components,
    labels=labels,
    strings=article_descriptions,
    width=600,
    height=500,
    title="t-SNE components of article descriptions",
)

As you can see in the chart above, even the highly compressed embeddings do a good job of clustering article descriptions by category. And it's worth emphasizing: this clustering is done with no knowledge of the labels themselves!

Also, if you look closely at the most egregious outliers, they are often due to mislabeling rather than poor embedding. For example, the majority of the blue World points in the green Sports cluster appear to be Sports stories.
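One way to check the mislabeling observation numerically is to flag points whose nearest neighbors in embedding space mostly carry a different label; flagged points are candidates for manual review, not proof of a wrong label. `majority_disagrees` is a hypothetical helper, and the six toy points below are made up, with index 2 deliberately given the wrong label.

```python
import numpy as np


def majority_disagrees(embeddings: np.ndarray, labels: list[str], k: int = 5) -> list[int]:
    """Hypothetical check: indices whose k nearest neighbors (cosine) mostly have a different label."""
    normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    distances = 1.0 - normed @ normed.T
    np.fill_diagonal(distances, np.inf)  # a point is not its own neighbor
    flagged = []
    for i in range(len(labels)):
        neighbors = np.argsort(distances[i])[:k]
        n_different = sum(labels[j] != labels[i] for j in neighbors)
        if n_different > k / 2:
            flagged.append(i)
    return flagged


emb = np.array([
    [1.0, 0.0], [0.9, 0.1], [0.95, 0.05],   # a "Sports"-like cluster
    [0.0, 1.0], [0.1, 0.9], [0.05, 0.95],   # a "World"-like cluster
])
labs = ["Sports", "Sports", "World", "World", "World", "World"]  # index 2 sits in the Sports cluster
print(majority_disagrees(emb, labs, k=2))  # [2]
```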

Next, let's recolor the points by whether they are a source article, its nearest neighbors, or other.

In [71]:
# create labels for the recommended articles
def nearest_neighbor_labels(
    list_of_indices: list[int],
    k_nearest_neighbors: int = 5    # note this is merely a default that can be overridden by the call
) -> list[str]:
    """Return a list of labels to color the k nearest neighbors."""
    labels = ["Other" for _ in list_of_indices]
    source_index = list_of_indices[0]
    labels[source_index] = "Source"
    for i in range(k_nearest_neighbors):
        nearest_neighbor_index = list_of_indices[i + 1]
        labels[nearest_neighbor_index] = f"Nearest neighbour (top {k_nearest_neighbors})"
    return labels


tony_blair_labels = nearest_neighbor_labels(tony_blair_articles, k_nearest_neighbors=10)
chipset_security_labels = nearest_neighbor_labels(chipset_security_articles, k_nearest_neighbors=5)

In [75]:
# a 2D chart of nearest neighbors of the Tony Blair article
chart_from_components(
    components=tsne_components,
    labels=tony_blair_labels,
    strings=article_descriptions,
    width=800,
    height=700,
    title="Nearest neighbours of the Tony Blair article",
    category_orders={"label": ["Other", "Nearest neighbour (top 10)", "Source"]},
)

Looking at the 2D chart above, we can see that the articles about Tony Blair are somewhat close together inside the World news cluster. Interestingly, although the nearest neighbors (red) were closest in the high-dimensional space, they are not the closest points in this compressed 2D space. Compressing the embeddings down to 2 dimensions discards much of their information, and the nearest neighbors in the 2D space don't seem to be as relevant as those in the full embedding space.
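To quantify how much neighborhood structure survives compression, you can measure what fraction of an article's top-k neighbors in the full space are still among its top-k neighbors in the reduced space. A minimal sketch: `top_k` and `neighbor_overlap` are hypothetical helpers, both spaces use Euclidean distance (which ranks neighbors the same way as cosine distance only for unit-length vectors), and keeping the first two coordinates of random data is a crude stand-in for a real projection such as t-SNE or PCA.

```python
import numpy as np


def top_k(points: np.ndarray, source: int, k: int) -> set[int]:
    """Indices of the k points nearest to `source` (Euclidean distance), excluding the source."""
    d = np.linalg.norm(points - points[source], axis=1)
    d[source] = np.inf  # never count the source as its own neighbor
    return {int(i) for i in np.argsort(d)[:k]}


def neighbor_overlap(full: np.ndarray, reduced: np.ndarray, source: int, k: int) -> float:
    """Fraction of the source's top-k full-space neighbors that are also top-k in the reduced space."""
    return len(top_k(full, source, k) & top_k(reduced, source, k)) / k


rng = np.random.default_rng(0)
full = rng.normal(size=(200, 64))  # random stand-ins for embeddings
reduced = full[:, :2]              # crude stand-in for a 2-D projection
print(neighbor_overlap(full, reduced, source=0, k=10))
```

Applied to the real embeddings (converted with `np.array`) and `tsne_components`, a low overlap would be consistent with the observation above.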

In [73]:
# a 2D chart of nearest neighbors of the chipset security article
chart_from_components(
    components=tsne_components,
    labels=chipset_security_labels,
    strings=article_descriptions,
    width=800,
    height=600,
    title="Nearest neighbours of the chipset security article",
    category_orders={"label": ["Other", "Nearest neighbour (top 5)", "Source"]},
)

For the chipset security example, the 4 closest nearest neighbors in the full embedding space remain nearest neighbors in this compressed 2D visualization. The fifth is displayed as more distant, despite being closer in the full embedding space.

Should you want to, you can also make an interactive 3D plot of the embeddings with the function `chart_from_components_3D`. (Doing so will require recomputing the t-SNE components with `n_components=3`.)

In [78]:
# get embeddings for all article descriptions
# this is here to facilitate recalculation of the tsne_components with n_components = 3
embeddings = [embedding_from_string(string) for string in article_descriptions]
# compress the 1536-dimensional embeddings into 3 dimensions using t-SNE, overriding the default n_components=2
tsne_components = tsne_components_from_embeddings(embeddings, 3)
# get the article labels for coloring the chart
labels = df["label"].tolist()

# a 3D chart of nearest neighbors of the chipset security article
chart_from_components_3D(
    components=tsne_components,
    labels=chipset_security_labels,
    strings=article_descriptions,
    width=800,
    height=600,
    title="Nearest neighbours of the chipset security article",
    category_orders={"label": ["Other", "Nearest neighbour (top 5)", "Source"]},
)