# Search

In [2]:
import pandas as pd
import os
from openai import AzureOpenAI
import numpy as np

In [25]:
df = pd.read_csv('data/text.csv')
df

| Unnamed: 0 | content |
|---|---|
| 0 | Elephants are the largest land animals on Eart... |
| 1 | Giraffes are the tallest living terrestrial an... |
| 2 | Lions, often referred to as the 'king of beast... |
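pandas creates an `Unnamed: 0` column when a CSV was saved together with its row index (the default of `DataFrame.to_csv`) and is read back without saying which column is the index. The minimal, self-contained sketch below uses a tiny in-memory CSV with made-up rows to show the difference the `index_col=0` argument of `pd.read_csv` makes:

```python
import io

import pandas as pd

# A tiny CSV shaped like data/text.csv: the first header cell is empty
# because the file was written together with its row index.
csv_text = ",content\n0,Elephants are large.\n1,Giraffes are tall.\n"

without_index = pd.read_csv(io.StringIO(csv_text))
with_index = pd.read_csv(io.StringIO(csv_text), index_col=0)

print(list(without_index.columns))  # ['Unnamed: 0', 'content']
print(list(with_index.columns))     # ['content']
```

For this notebook that would be `pd.read_csv('data/text.csv', index_col=0)`, assuming the first column of the file really is the saved index; the outputs further down would then no longer contain the `Unnamed: 0` column.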


In [8]:
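openai_client: AzureOpenAI = AzureOpenAI(
    # No endpoint or key is passed, so the client reads them from the
    # AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_API_KEY environment variables.
    api_version="2024-06-01",
    max_retries=5,  # retry transient failures such as rate limiting (HTTP 429)
)

def get_embedding(text):
    """Return the embedding (a list of floats) for a single string."""
    # On Azure OpenAI, `model` is the deployment name; here the deployment is
    # assumed to be named after the text-embedding-ada-002 model.
    response = openai_client.embeddings.create(input=text, model="text-embedding-ada-002")
    return response.data[0].embedding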
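Because `AzureOpenAI` is created without an endpoint or key, it depends on the `AZURE_OPENAI_ENDPOINT` and `AZURE_OPENAI_API_KEY` environment variables, and creating the client fails if they are missing (unless you authenticate another way, for example with a Microsoft Entra ID token). A small pre-flight check such as the hypothetical helper below can report the missing settings up front; `missing_azure_settings` is an illustrative name, not part of the `openai` package.

```python
import os
from collections.abc import Mapping

# Variables the AzureOpenAI client falls back to when no endpoint or key is passed.
REQUIRED_AZURE_VARS = ("AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_API_KEY")


def missing_azure_settings(env: Mapping[str, str] = os.environ) -> list[str]:
    """Hypothetical helper: names of required variables that are unset or empty."""
    return [name for name in REQUIRED_AZURE_VARS if not env.get(name)]


missing = missing_azure_settings()
if missing:
    print("Set these environment variables before creating the client:", ", ".join(missing))
```

Passing a plain dictionary as `env` makes the check easy to exercise without touching the real environment.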

In [9]:
# Embed every row; this sends one API request per row
df["embedding"] = df["content"].apply(get_embedding)
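`apply` makes one request per row, which is fine for three rows but slow for larger tables. The embeddings endpoint also accepts a list of strings as `input`, so rows can be embedded in batches. The sketch below keeps the batching logic independent of the API by taking a hypothetical `embed_batch` callable (a list of strings in, one vector per string out). The default batch size of 16 is a conservative assumption, not a documented limit; check how many inputs per request your deployment accepts.

```python
from collections.abc import Callable, Sequence

Vector = list[float]


def embed_in_batches(
    texts: Sequence[str],
    embed_batch: Callable[[list[str]], list[Vector]],
    batch_size: int = 16,
) -> list[Vector]:
    """Embed texts in chunks of batch_size, preserving the input order."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    embeddings: list[Vector] = []
    for start in range(0, len(texts), batch_size):
        batch = list(texts[start:start + batch_size])
        vectors = embed_batch(batch)
        if len(vectors) != len(batch):
            raise ValueError("embed_batch must return one vector per input text")
        embeddings.extend(vectors)
    return embeddings


# Offline stand-in for the API so the sketch runs without credentials:
# each "embedding" is just [length of the text, number of spaces].
def fake_embed_batch(batch: list[str]) -> list[Vector]:
    return [[float(len(t)), float(t.count(" "))] for t in batch]


print(embed_in_batches(["a b", "cd", "e f g"], fake_embed_batch, batch_size=2))
# [[3.0, 1.0], [2.0, 0.0], [5.0, 2.0]]
```

With the notebook's client, `embed_batch` could be `lambda batch: [item.embedding for item in sorted(openai_client.embeddings.create(input=batch, model="text-embedding-ada-002").data, key=lambda item: item.index)]`, and the column filled with `df["embedding"] = embed_in_batches(df["content"].tolist(), embed_batch)`. Sorting by `index` guards the order of the returned vectors.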

In [17]:
question = "Wat kan je me vertellen over leeuwen?"  # Dutch for "What can you tell me about lions?"

Cosine similarity measures how similar two vectors are by looking at the angle between them. Imagine two arrows starting at the same point. If they point in the same direction, the angle between them is 0° and the cosine similarity is 1: they are as similar as possible. If they are perpendicular, the cosine similarity is 0, and if they point in exactly opposite directions, the angle is 180° and the cosine similarity is -1.

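In text search, each document (and each question) is turned into a list of numbers, a vector called an embedding, and cosine similarity compares those vectors. It is computed as the dot product of the two vectors divided by the product of their lengths, which equals the cosine of the angle between them. The smaller the angle, the higher the score and the more similar the texts are; the lengths of the vectors do not affect the result.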
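A small worked example with two-dimensional vectors makes the arrow picture concrete. It uses the same formula as the `cosine_similarity` function in the next cell; the vectors are made up for illustration.

```python
import numpy as np


def cosine_similarity(a, b):
    # Cosine of the angle between a and b: dot product over the product of the lengths
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))


right = np.array([1.0, 0.0])

print(cosine_similarity(right, np.array([3.0, 0.0])))   # same direction, longer arrow: 1.0
print(cosine_similarity(right, np.array([0.0, 2.0])))   # perpendicular: 0.0
print(cosine_similarity(right, np.array([-1.0, 0.0])))  # opposite direction: -1.0
print(cosine_similarity(right, np.array([1.0, 1.0])))   # 45 degrees apart: about 0.7071
```

The first line shows that stretching an arrow does not change the score: only the direction matters.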

In [14]:
def cosine_similarity(a, b):
    """Cosine of the angle between vectors a and b, from -1 (opposite) to 1 (same direction)."""
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
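`cosine_similarity` compares one pair of vectors at a time, so scoring a table means one call per row. An alternative, not what this notebook does, is to stack the stored embeddings into a matrix and score them all with a single matrix-vector product. OpenAI documents its embeddings as normalized to length 1, in which case dividing by the lengths changes nothing and the score is just the dot product; the sketch below still divides so it works for any vectors. `cosine_similarities` is a hypothetical name.

```python
import numpy as np


def cosine_similarities(matrix, query):
    """Cosine similarity between each row of `matrix` (n x d) and `query` (d,)."""
    matrix = np.asarray(matrix, dtype=float)
    query = np.asarray(query, dtype=float)
    # Row-wise dot products divided by (length of each row * length of the query)
    return (matrix @ query) / (np.linalg.norm(matrix, axis=1) * np.linalg.norm(query))


stored = [[1.0, 0.0], [0.0, 2.0], [-1.0, 0.0]]
print(cosine_similarities(stored, [2.0, 0.0]))  # [ 1.  0. -1.]
```

Applied to this notebook's data, the matrix would be `np.array(df["embedding"].tolist())` and the query the question's embedding.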

In [15]:
def search(df, question, n=1):
    """Return the n rows of df whose embeddings are most similar to the question."""
    # Convert the question into an embedding
    q_embedding = get_embedding(question)

    # Calculate the similarity between the question and every stored embedding;
    # assign() returns a new DataFrame, so the caller's df is left unchanged
    scored = df.assign(similarity=df["embedding"].apply(lambda x: cosine_similarity(x, q_embedding)))

    # Return the top n most similar rows
    return scored.sort_values("similarity", ascending=False).head(n)
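To check the ranking logic of `search` without calling the API, the question embedding can be replaced by a hand-made vector. The sketch below uses a hypothetical `search_by_vector` variant that takes the query embedding directly, and toy three-dimensional "embeddings" with invented values (real text-embedding-ada-002 vectors have 1536 dimensions).

```python
import numpy as np
import pandas as pd


def cosine_similarity(a, b):
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))


def search_by_vector(df, q_embedding, n=1):
    """Same ranking as search(), but with a precomputed query embedding."""
    scored = df.assign(similarity=df["embedding"].apply(lambda x: cosine_similarity(x, q_embedding)))
    return scored.sort_values("similarity", ascending=False).head(n)


toy = pd.DataFrame({
    "content": ["elephants", "giraffes", "lions"],
    # Invented vectors: each animal mostly "points" along its own axis
    "embedding": [[0.9, 0.1, 0.0], [0.1, 0.9, 0.1], [0.0, 0.2, 0.9]],
})

# A query vector that points mostly along the "lions" axis
top = search_by_vector(toy, [0.1, 0.1, 1.0], n=2)
print(top[["content", "similarity"]])
```

With these values the lions row scores about 0.99 and ranks first, followed by giraffes; `toy` itself gains no `similarity` column.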

In [18]:
search(df, question)

| Unnamed: 0 | content | embedding | similarity |
|---|---|---|---|
| 1 | Giraffes are the tallest living terrestrial an... | [-0.015602780506014824, -0.003377889283001423,... | 0.804679 |
