# 03.1 - Embedding with Chroma

To find semantically related documents, we'll use Chroma (https://www.trychroma.com/), a lightweight vector data store. Chroma supports swappable embedding models, metadata filtering, keyword search, and several distance functions. We'll use these features to evaluate approaches to organizing papers for downstream processing (search, summarization, keyword extraction, etc.).

The default Chroma embedding model is used in this notebook. The other "03" notebooks show how different embedding models can be used.

## Section 0 - Notebook Setup

In [None]:
%pip install --upgrade --quiet chromadb

In [None]:
%pip install --upgrade --quiet sentence_transformers

Load articles and prune ones without abstracts, since we're using the abstracts for generating the embeddings.

In [None]:
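import pandas as pd
from genscai import paths

df_modeling_papers = pd.read_json(paths.data / "modeling_papers_0.json", orient="records", lines=True)

# drop papers with a missing or blank abstract, since the abstracts are what we embed
df_modeling_papers = df_modeling_papers.dropna(subset=["abstract"])
df_modeling_papers = df_modeling_papers[df_modeling_papers["abstract"].str.strip() != ""]
df_modeling_papers.shape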

## Section I - Create a Vector Database using Chroma

Create a Chroma database to store the vector data, then create a collection in it. Each Chroma collection can have its own embedding function and distance function.

In [None]:
import chromadb
from genscai import paths

client = chromadb.PersistentClient(path=str(paths.output / "chroma_db"))

collection_name = "articles-default-embeddings"

collection = client.create_collection(name=collection_name)
# collection = client.get_collection(name=collection_name)  # reopen an existing collection
# client.delete_collection(name=collection_name)  # delete the collection to start over

Add documents to the collection if the collection is new, or there are new documents to add.

In [None]:
from tqdm import tqdm

# use only the first 100 documents for testing
documents = df_modeling_papers.get("abstract").tolist()[:100]
ids = df_modeling_papers.get("id").tolist()[:100]

for i in tqdm(range(len(documents))):
    collection.add(documents=documents[i], ids=ids[i])
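The loop above calls `collection.add` once per abstract, so each call embeds a single document. `collection.add` also accepts lists of documents and ids, so the same data can be added in fewer calls, which can reduce per-call overhead. Below is a minimal sketch of a batching helper; `batched_pairs` is a hypothetical name and the default batch size of 32 is an arbitrary assumption, not a value from this notebook.

```python
from typing import Iterator, List, Sequence, Tuple


def batched_pairs(
    documents: Sequence[str], ids: Sequence[str], batch_size: int = 32
) -> Iterator[Tuple[List[str], List[str]]]:
    """Yield aligned (documents, ids) slices of at most `batch_size` items."""
    if len(documents) != len(ids):
        raise ValueError("documents and ids must have the same length")
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(documents), batch_size):
        end = start + batch_size
        yield list(documents[start:end]), list(ids[start:end])


# Example with toy data; in the notebook each batch would instead be passed as
#   collection.add(documents=doc_batch, ids=id_batch)
docs = [f"abstract {n}" for n in range(5)]
doc_ids = [f"id-{n}" for n in range(5)]
for doc_batch, id_batch in batched_pairs(docs, doc_ids, batch_size=2):
    print(id_batch)
```

With five toy documents and `batch_size=2`, this yields three batches of ids: `['id-0', 'id-1']`, `['id-2', 'id-3']`, and `['id-4']`. Keeping documents and ids in the same slice guarantees that each abstract stays paired with its id.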

In [None]:
results = collection.query(query_texts=["agent-based model for malaria"], n_results=10)

results
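The query result is a dictionary in which each field (`ids`, `documents`, `distances`, and so on) holds one inner list per query text, ordered from nearest to farthest. A minimal sketch for flattening one query's hits into readable rows follows; `top_hits` is a hypothetical helper, and it assumes `distances` and `documents` were included in the result, as they are in the output shown further below.

```python
def top_hits(results: dict, query_index: int = 0, max_chars: int = 80) -> list:
    """Flatten one query's results into (rank, id, distance, snippet) tuples.

    Assumes the Chroma query-result layout: each field holds one inner list
    per query text, and 'distances' and 'documents' are present.
    """
    ids = results["ids"][query_index]
    distances = results["distances"][query_index]
    documents = results["documents"][query_index]
    hits = []
    for rank, (doc_id, dist, doc) in enumerate(zip(ids, distances, documents), start=1):
        # truncate long abstracts so each hit fits on one line
        snippet = doc[:max_chars] + ("..." if len(doc) > max_chars else "")
        hits.append((rank, doc_id, round(dist, 4), snippet))
    return hits


# toy result dict with the same nested shape as collection.query(...) output
example = {
    "ids": [["a1", "b2"]],
    "distances": [[0.25, 0.5]],
    "documents": [["Agent-based model of malaria transmission.", "A flooding study."]],
    "embeddings": None,
}
for hit in top_hits(example):
    print(hit)
```

In the notebook, `top_hits(results)` would list the ten nearest abstracts for the first (and here only) query text.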

Create a new collection that uses cosine distance instead of the default squared L2 distance. Ref: https://docs.trychroma.com/guides#changing-the-distance-function
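To see why the choice of distance function matters, here is a minimal pure-Python sketch of the three distance functions Chroma offers (`"l2"`, `"ip"`, `"cosine"`), written from the formulas in the Chroma guide referenced above. The toy 2-D vectors are stand-ins for real embeddings. For vectors of different lengths, squared L2 and cosine distance can rank the same candidates in opposite orders; for unit-length vectors, squared L2 is exactly twice the cosine distance ($\lVert a-b\rVert^2 = 2 - 2\cos\theta$), so the two rankings agree.

```python
import math


def squared_l2(a, b):
    """Chroma's default "l2" space: sum of squared differences."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def inner_product_distance(a, b):
    """Chroma's "ip" space: 1 minus the dot product."""
    return 1.0 - sum(x * y for x, y in zip(a, b))


def cosine_distance(a, b):
    """Chroma's "cosine" space: 1 minus the cosine similarity."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


def rank(query, vectors, distance):
    """Return the indices of `vectors`, nearest first, under `distance`."""
    return sorted(range(len(vectors)), key=lambda i: distance(query, vectors[i]))


# toy "embeddings": a long vector nearly aligned with the query,
# and a short vector that points further away from it
query = [1.0, 0.0]
vectors = [[3.0, 0.3], [0.9, 0.5]]

print(rank(query, vectors, squared_l2))       # [1, 0]
print(rank(query, vectors, cosine_distance))  # [0, 1]
```

Squared L2 favors the short vector because it is close to the query in absolute terms, while cosine distance favors the long vector because its direction is nearly identical to the query's.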

In [8]:
collection_name = "articles-default-embeddings-cosine-distance"

collection = client.create_collection(name=collection_name, metadata={"hnsw:space": "cosine"})

In [9]:
for i in tqdm(range(len(documents))):
    collection.add(documents=documents[i], ids=ids[i])

100%|██████████| 100/100 [02:25<00:00,  1.45s/it]


In [None]:
results = collection.query(
    query_texts=["agent-based model for malaria"],
    n_results=10,
)

results

{'ids': [['6c6ba8c664304e2e0bab2fedf0ab7172',
   'd09d1bf9621aa34f3def62f792832afb',
   '720774c69f06f1db3d52d62cd5c8dea1',
   'bd125f9c55d413d1cb7818e6e6193340',
   '338f54dd377e7ca36a88b0960325b332',
   '970de68ec0c738b559989f8a466900a1',
   'c470587524fc04e2f2cb569987ff2601',
   '7ad12941613f8c448d0e14b02b1f8931',
   '15b0cbbeca7b62fce9e34e0f1db428ec',
   'c92d823b6a0f1819d36c0f99067a22b4']],
 'embeddings': None,
 'documents': [['In February 2019, a major flooding event occurred in Townsville, North Queensland, Australia. Here we present a prediction of the occurrence of mosquito-borne diseases (MBDs) after the flooding. We used a mathematical modelling approach based on mosquito population abundance, survival, and size as well as current infectiousness to predict the changes in the occurrences of MBDs due to flooding in the study area. Based on 2019 year-to-date number of notifiable MBDs, we predicted an increase in number of cases, with a peak at 104 by one-half month after the fl