# 03.1 - Embedding with Chroma

To find semantically related documents, we'll use Chroma (https://www.trychroma.com/), a lightweight vector database. Chroma supports swappable embedding models, metadata filtering, keyword search, and several distance functions. We'll use these features to evaluate approaches to organizing papers for downstream processing (search, summarization, keyword extraction, etc.).

The default Chroma embedding model is used in this notebook. The other "03" notebooks show how different embedding models can be used.

## Section 0 - Notebook Setup

In [None]:
%pip install --upgrade --quiet chromadb

In [None]:
%pip install --upgrade --quiet sentence_transformers

Load the articles and drop any without an abstract, since the abstracts are used to generate the embeddings.

In [None]:
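import pandas as pd
from genscai import paths

df_modeling_papers = pd.read_json(paths.data / "modeling_papers_0.json", orient="records", lines=True)

# drop papers with a missing or empty abstract, since the abstracts are what get embedded
df_modeling_papers = df_modeling_papers.dropna(subset=["abstract"])
df_modeling_papers = df_modeling_papers[df_modeling_papers["abstract"].str.strip() != ""]
df_modeling_papers.shape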

## Section I - Create a Vector Database using Chroma

Create a persistent Chroma database to store the vector data, and create a collection in it. Each Chroma collection can have its own embedding function and distance function.

In [None]:
import chromadb
from genscai import paths

client = chromadb.PersistentClient(path=str(paths.output / "chroma_db"))

collection_name = "papers-default-embeddings"

# get_or_create_collection lets this cell be re-run; create_collection raises if the name already exists
collection = client.get_or_create_collection(name=collection_name)
# client.delete_collection(name=collection_name)  # uncomment to start over (returns None)

Add documents to the collection if the collection is new, or there are new documents to add.

In [None]:
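from tqdm import tqdm

# use only the first 100 documents for testing
documents = df_modeling_papers["abstract"].tolist()[:100]
# Chroma requires string ids
ids = df_modeling_papers["id"].astype(str).tolist()[:100]

# collection.add accepts lists, so adding in batches is much faster than one call per document
batch_size = 25
for start in tqdm(range(0, len(documents), batch_size)):
    collection.add(
        documents=documents[start : start + batch_size],
        ids=ids[start : start + batch_size],
    )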

In [None]:
results = collection.query(query_texts=["agent-based models for malaria"], n_results=10)
results
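`query` returns parallel lists of lists: one inner list per query text, holding ids, distances, documents, and so on. To compare collections, it can help to flatten one query's hits into a table. The sketch below is a hypothetical helper (`query_results_to_frame`). It assumes only the `ids`, `distances`, and `documents` keys. The dictionary in the usage example is hand-built with made-up values that mimic that shape.

```python
import pandas as pd


def query_results_to_frame(query_results, query_index=0):
    """Flatten one query's hits from a Chroma `query` result into a DataFrame.

    `query_index` selects which query text to flatten, because each key
    holds one inner list per query.
    """
    frame = pd.DataFrame(
        {
            "id": query_results["ids"][query_index],
            "distance": query_results["distances"][query_index],
            "document": query_results["documents"][query_index],
        }
    )
    # rank 1 is the closest hit
    frame.index = pd.RangeIndex(1, len(frame) + 1, name="rank")
    return frame


# Hand-built dict with the same shape as a `query` result (values are made up).
fake_results = {
    "ids": [["p7", "p2"]],
    "distances": [[0.41, 0.58]],
    "documents": [["abstract of paper seven", "abstract of paper two"]],
}
print(query_results_to_frame(fake_results))
```

In the notebook, `query_results_to_frame(results)` would tabulate the ten hits returned above.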

Create a new collection that uses cosine distance rather than squared L2 (the default). Ref: https://docs.trychroma.com/guides#changing-the-distance-function
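It helps to first see how the two distance functions relate. Chroma documents its `"l2"` space as squared Euclidean distance and its `"cosine"` space as 1 minus cosine similarity. For unit-length vectors, the identity ‖a − b‖² = 2(1 − cos θ) holds. So if the embeddings are normalized, both spaces return the same ranking, and the L2 distances are exactly twice the cosine distances. Whether the default model's embeddings are already normalized is an assumption worth checking. The plain-Python sketch below normalizes explicitly so that the identity is visible.

```python
import math


def squared_l2(a, b):
    """Squared Euclidean distance (the quantity Chroma's "l2" space reports)."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def cosine_distance(a, b):
    """1 - cosine similarity (the quantity Chroma's "cosine" space reports)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


def normalize(v):
    """Scale a vector to unit length."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


a, b = [3.0, 4.0], [4.0, 3.0]
# raw vectors: the two measures are unrelated in scale
print(squared_l2(a, b), cosine_distance(a, b))

# unit vectors: squared L2 equals twice the cosine distance
ua, ub = normalize(a), normalize(b)
print(squared_l2(ua, ub), 2 * cosine_distance(ua, ub))
```

When embeddings are not normalized, vector length affects L2 but not cosine, and the two collections can rank documents differently.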

In [None]:
collection_name = "papers-default-embeddings-cosine-distance"

# the distance function is fixed when the collection is created
collection = client.get_or_create_collection(name=collection_name, metadata={"hnsw:space": "cosine"})
# client.delete_collection(name=collection_name)  # uncomment to start over (returns None)

In [None]:
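# add in batches, as for the first collection
batch_size = 25
for start in tqdm(range(0, len(documents), batch_size)):
    collection.add(
        documents=documents[start : start + batch_size],
        ids=ids[start : start + batch_size],
    )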

In [None]:
# same query text as for the L2 collection, so the two result sets are directly comparable
results = collection.query(query_texts=["agent-based models for malaria"], n_results=10)
results