## Evaluate ChromaDB-based Information Retrieval (IR)
using all-MiniLM-L6-v2 embeddings


In [2]:
# Set this to the root directory of the project before running the notebook.
ProjectRoot = "<PROVIDE PROJECT ROOT PATH>"
# Folder (with trailing slash) that holds the knowledge base and the Q/A sets.
DatasetRoot = f"{ProjectRoot}/Dataset/"

#### Dependencies

In [3]:
try:
    import chromadb
except ImportError:
    !pip install chromadb

Collecting chromadb
  Downloading chromadb-0.5.5-py3-none-any.whl.metadata (6.8 kB)
Collecting chroma-hnswlib==0.7.6 (from chromadb)
  Downloading chroma_hnswlib-0.7.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (252 bytes)
Collecting fastapi>=0.95.2 (from chromadb)
  Downloading fastapi-0.112.0-py3-none-any.whl.metadata (27 kB)
Collecting uvicorn>=0.18.3 (from uvicorn[standard]>=0.18.3->chromadb)
  Downloading uvicorn-0.30.5-py3-none-any.whl.metadata (6.6 kB)
Collecting posthog>=2.4.0 (from chromadb)
  Downloading posthog-3.5.0-py2.py3-none-any.whl.metadata (2.0 kB)
Collecting onnxruntime>=1.14.1 (from chromadb)
  Downloading onnxruntime-1.18.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.3 kB)
Collecting opentelemetry-api>=1.2.0 (from chromadb)
  Downloading opentelemetry_api-1.26.0-py3-none-any.whl.metadata (1.4 kB)
Collecting opentelemetry-exporter-otlp-proto-grpc>=1.2.0 (from chromadb)
  Downloading opentelemetry_exporter_otlp_pro

In [4]:
import chromadb
import json
import regex as re
import numpy as np
import pandas as pd

#### Init IR

In [5]:
# Load the full articles from the JSON file. The file is read as a mapping
# whose keys become document IDs and whose values become document texts.
# JSON is UTF-8 by specification, so the encoding is stated explicitly instead
# of relying on the platform default.
with open(DatasetRoot + 'raw_knowledge.json', 'r', encoding='utf-8') as f:
    raw_text_json = json.load(f)

# keys() and values() iterate in the same order, so both lists stay aligned.
raw_text_ids = list(raw_text_json.keys())
raw_text_list = list(raw_text_json.values())

In [6]:
# Set up Chroma in-memory.
client = chromadb.Client()

# get_or_create_collection keeps this cell re-runnable within one kernel:
# create_collection fails when "knowledge-store" already exists.
Retriever = client.get_or_create_collection("knowledge-store")

# Add the docs to the knowledge store. upsert (instead of add) overwrites
# entries whose ID is already present, so re-running the cell is safe.
Retriever.upsert(
    ids=raw_text_ids,
    documents=raw_text_list,
)

/root/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx.tar.gz: 100%|██████████| 79.3M/79.3M [00:07<00:00, 11.1MiB/s]


#### Eval Retrieval Precision, Recall, MRR

In [7]:
# Load the context-question train set that was created by doc2query.
# DatasetRoot already ends with "/", so no extra separator is prepended
# (avoids a ".../Dataset//q_a_trainset.csv" path).
train_df = pd.read_csv(DatasetRoot + 'q_a_trainset.csv')

In [8]:
def evaluate_retrieval(retriever, eval_dataset, top_n=3):
    """Score a retriever with mean precision@k, recall@k and MRR.

    Each row of ``eval_dataset`` supplies a question and the ID of the
    paragraph it was generated from; that paragraph is treated as the only
    relevant (ground-truth) document for the question.

    Args:
        retriever: Object exposing ``query(query_texts=..., n_results=...)``
            whose result has an ``'ids'`` entry holding, for the query, a list
            of document IDs in rank order. Every ID must be convertible with
            ``int()``.
        eval_dataset: pandas DataFrame with ``'raw_para_id'`` and
            ``'question'`` columns.
            NOTE(review): ``raw_para_id`` is assumed to compare equal to the
            integer form of the stored document ID -- confirm against the
            dataset.
        top_n: Number of documents requested per question.

    Returns:
        Tuple ``(avg_precision, avg_recall, avg_mrr)``: the per-question
        metrics averaged over all rows with ``np.mean``.
    """
    precision_at_k = []
    recall_at_k = []
    reciprocal_ranks = []

    for _, eval_data in eval_dataset.iterrows():
        # Query text and the ID of the paragraph the question was generated from.
        raw_para_id = eval_data['raw_para_id']
        query = eval_data['question']

        # Search with the retriever that was passed in (not a notebook global).
        query_result = retriever.query(query_texts=query, n_results=top_n)
        doc_indices = [int(idx) for idx in query_result['ids'][0]]

        relevant_docs = {raw_para_id}
        retrieved_docs = set(doc_indices)
        hits = relevant_docs & retrieved_docs

        # An empty result list counts as precision 0 instead of dividing by zero.
        precision = len(hits) / len(retrieved_docs) if retrieved_docs else 0.0
        recall = len(hits) / len(relevant_docs)
        precision_at_k.append(precision)
        recall_at_k.append(recall)

        # Reciprocal of the 1-based rank of the first relevant hit; 0 if none.
        reciprocal_rank = next(
            (
                1.0 / rank
                for rank, doc_index in enumerate(doc_indices, start=1)
                if doc_index in relevant_docs
            ),
            0.0,
        )
        reciprocal_ranks.append(reciprocal_rank)

    avg_precision = np.mean(precision_at_k)
    avg_recall = np.mean(recall_at_k)
    avg_mrr = np.mean(reciprocal_ranks)

    return avg_precision, avg_recall, avg_mrr



In [9]:
# Retrieval quality when only the single top-ranked document is requested (k = 1).
precision, recall, mrr = evaluate_retrieval(
    Retriever,
    train_df[['raw_para_id', 'question']],
    top_n=1,
)

print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, MRR: {mrr:.4f}")

Precision: 0.4375, Recall: 0.4375, MRR: 0.4375


In [10]:
# Retrieval quality when the three top-ranked documents are requested (k = 3).
precision, recall, mrr = evaluate_retrieval(
    Retriever,
    train_df[['raw_para_id', 'question']],
    top_n=3,
)

print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, MRR: {mrr:.4f}")

Precision: 0.2396, Recall: 0.7188, MRR: 0.5625


In [11]:
# Retrieval quality when the five top-ranked documents are requested (k = 5).
precision, recall, mrr = evaluate_retrieval(
    Retriever,
    train_df[['raw_para_id', 'question']],
    top_n=5,
)

print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, MRR: {mrr:.4f}")

Precision: 0.1708, Recall: 0.8542, MRR: 0.5943
