In [3]:
import openai
import pandas as pd
import os
import wget
from ast import literal_eval

# Chroma's client library for Python
import chromadb

# I've set this to our new embeddings model, this can be changed to the embedding model of your choice
EMBEDDING_MODEL = "text-embedding-ada-002"

# Ignore unclosed SSL socket warnings - optional in case you get these errors
import warnings

warnings.filterwarnings(action="ignore", message="unclosed", category=ResourceWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning) 

In [4]:
embeddings_url = 'https://cdn.openai.com/API/examples/data/vector_database_wikipedia_articles_embedded.zip'

# The file is ~700 MB so this will take some time
wget.download(embeddings_url)

'vector_database_wikipedia_articles_embedded.zip'

In [5]:
import zipfile
with zipfile.ZipFile("vector_database_wikipedia_articles_embedded.zip","r") as zip_ref:
    zip_ref.extractall("../data")

In [6]:
article_df = pd.read_csv('../data/vector_database_wikipedia_articles_embedded.csv')

In [7]:
article_df.head()

Unnamed: 0,id,url,title,text,title_vector,content_vector,vector_id
0,1,https://simple.wikipedia.org/wiki/April,April,April is the fourth month of the year in the J...,"[0.001009464613161981, -0.020700545981526375, ...","[-0.011253940872848034, -0.013491976074874401,...",0
1,2,https://simple.wikipedia.org/wiki/August,August,August (Aug.) is the eighth month of the year ...,"[0.0009286514250561595, 0.000820168002974242, ...","[0.0003609954728744924, 0.007262262050062418, ...",1
2,6,https://simple.wikipedia.org/wiki/Art,Art,Art is a creative activity that expresses imag...,"[0.003393713850528002, 0.0061537534929811954, ...","[-0.004959689453244209, 0.015772193670272827, ...",2
3,8,https://simple.wikipedia.org/wiki/A,A,A or a is the first letter of the English alph...,"[0.0153952119871974, -0.013759135268628597, 0....","[0.024894846603274345, -0.022186409682035446, ...",3
4,9,https://simple.wikipedia.org/wiki/Air,Air,Air refers to the Earth's atmosphere. Air is a...,"[0.02224554680287838, -0.02044147066771984, -0...","[0.021524671465158463, 0.018522677943110466, -...",4


In [8]:
# Read vectors from strings back into a list
article_df['title_vector'] = article_df.title_vector.apply(literal_eval)
article_df['content_vector'] = article_df.content_vector.apply(literal_eval)

# Set vector_id to be a string
article_df['vector_id'] = article_df['vector_id'].apply(str)

In [9]:
article_df.info(show_counts=True)

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 25000 entries, 0 to 24999
Data columns (total 7 columns):
 #   Column          Non-Null Count  Dtype 
---  ------          --------------  ----- 
 0   id              25000 non-null  int64 
 1   url             25000 non-null  object
 2   title           25000 non-null  object
 3   text            25000 non-null  object
 4   title_vector    25000 non-null  object
 5   content_vector  25000 non-null  object
 6   vector_id       25000 non-null  object
dtypes: int64(1), object(6)
memory usage: 1.3+ MB


In [10]:
chroma_client = chromadb.EphemeralClient() # Equivalent to chromadb.Client(), ephemeral.
# Uncomment for persistent client
# chroma_client = chromadb.PersistentClient()

In [11]:
from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction

# Test that your OpenAI API key is correctly set as an environment variable
# Note. if you run this notebook locally, you will need to reload your terminal and the notebook for the env variables to be live.

# Note. alternatively you can set a temporary env variable like this:
# os.environ["OPENAI_API_KEY"] = 'sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'

if os.getenv("OPENAI_API_KEY") is not None:
    openai.api_key = os.getenv("OPENAI_API_KEY")
    print ("OPENAI_API_KEY is ready")
else:
    print ("OPENAI_API_KEY environment variable not found")


embedding_function = OpenAIEmbeddingFunction(api_key=os.environ.get('OPENAI_API_KEY'), model_name=EMBEDDING_MODEL)

wikipedia_content_collection = chroma_client.create_collection(name='wikipedia_content', embedding_function=embedding_function)
wikipedia_title_collection = chroma_client.create_collection(name='wikipedia_titles', embedding_function=embedding_function)

OPENAI_API_KEY is ready


In [14]:
# Define the maximum batch size
max_batch_size = 5461

# Split your data into smaller batches
for start_idx in range(0, len(article_df), max_batch_size):
    end_idx = start_idx + max_batch_size
    batch_ids = article_df.vector_id.tolist()[start_idx:end_idx]
    batch_embeddings = article_df.content_vector.tolist()[start_idx:end_idx]

    # Add each batch to the collection
    wikipedia_content_collection.add(
        ids=batch_ids,
        embeddings=batch_embeddings
    )


In [16]:
# Define the maximum batch size
max_batch_size = 5461

# Split your data into smaller batches and add them to the collection
for start_idx in range(0, len(article_df), max_batch_size):
    end_idx = start_idx + max_batch_size
    batch_ids = article_df.vector_id.tolist()[start_idx:end_idx]
    batch_embeddings = article_df.title_vector.tolist()[start_idx:end_idx]

    # Add each batch to the title collection
    wikipedia_title_collection.add(
        ids=batch_ids,
        embeddings=batch_embeddings
    )


In [17]:
def query_collection(collection, query, max_results, dataframe):
    results = collection.query(query_texts=query, n_results=max_results, include=['distances']) 
    df = pd.DataFrame({
                'id':results['ids'][0], 
                'score':results['distances'][0],
                'title': dataframe[dataframe.vector_id.isin(results['ids'][0])]['title'],
                'content': dataframe[dataframe.vector_id.isin(results['ids'][0])]['text'],
                })
    
    return df

In [18]:
title_query_result = query_collection(
    collection=wikipedia_title_collection,
    query="modern art in Europe",
    max_results=10,
    dataframe=article_df
)
title_query_result.head()

Unnamed: 0,id,score,title,content
116,12249,0.265009,Europe,Europe is the western part of the continent of...
1332,12248,0.29057,European,European may mean:\nA person or attribute of t...
2885,12225,0.314753,Scandinavia,Scandinavia is a group of countries in norther...
12212,1332,0.317153,Western civilization,"Western civilization, western culture or the ..."
12216,12216,0.32114,Eastern Europe,Eastern Europe is the eastern region of Europe...


In [21]:
content_query_result = query_collection(
    collection=wikipedia_content_collection,
    query="Albanias capital",
    max_results=10,
    dataframe=article_df
)
content_query_result.head()

Unnamed: 0,id,score,title,content
789,7042,0.234331,Albania,"Albania ( ; ), officially called the Republic ..."
3068,789,0.280495,Algiers,"Algiers is the capital city of Algeria, which ..."
5761,16423,0.336675,"Albany, New York",Albany ( ) is the capital city of the U.S. sta...
6392,9397,0.363711,Palermo,Palermo is an Italian city. It is the capital ...
7042,5761,0.369382,Tirana,Tirana ( or Tirana) is the capital city of Rep...


In [22]:
client = chromadb.PersistentClient(path='chromadb')