In [None]:
from langchain_community.vectorstores import Chroma
import sys
# Put the parent directory on the import path so `constants` can be imported from there.
sys.path.insert(0, "..")
import constants
from langchain_openai import OpenAIEmbeddings
import torch
from tqdm import tqdm

In [None]:
import os

# Provide the OpenAI key (presumably read by OpenAIEmbeddings below) from the
# project `constants` module, but only when the environment does not already
# define one, so an externally configured key is not silently overwritten.
os.environ.setdefault("OPENAI_API_KEY", constants.OPENAI_API)

In [None]:
# Fixed evaluation queries: each asks about one gene in the context of two
# lymphoma subtypes.
TEST_QUERIES = [
    "Retrieve information about gene AICDA, DLBCL (diffuse large B-cell lymphoma) and FL (follicular lymphoma), especially in the context of AICDA's relevance to DLBCL and FL.",
    "Retrieve information about gene BCL6, DLBCL (diffuse large B-cell lymphoma) and MCL (mantle cell lymphoma), especially in the context of BCL6's relevance to DLBCL and MCL.",
    "Retrieve information about gene AASS, cHL (Classical Hodgkin Lymphoma) and MCL (mantle cell lymphoma), especially in the context of AASS's relevance to MCL and cHL."
]

# Preferred GPU for the brute-force similarity scan. Fall back to the default
# CUDA device, then to CPU, when that GPU is not present so the notebook still
# runs on other machines. DEVICE stays a string for `.to(...)` / `device=`.
PREFERRED_DEVICE = torch.device("cuda:7")
if torch.cuda.is_available() and PREFERRED_DEVICE.index < torch.cuda.device_count():
    DEVICE = str(PREFERRED_DEVICE)
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [None]:
# OpenAI embedding model, shared by the vector store and the query embeddings.
embeddings = OpenAIEmbeddings()
# Open the Chroma store persisted under `constants.OMIM_PERSIST_DIRECTORY`,
# resolved relative to the parent of the working directory.
vectorstore = Chroma(
    embedding_function=embeddings,
    persist_directory="../" + constants.OMIM_PERSIST_DIRECTORY,
)

In [None]:
# Embed the evaluation queries with the same `embeddings` object given to the
# vector store, as a float32 matrix (one row per query) placed on DEVICE.
test_embeds = torch.tensor(
    embeddings.embed_documents(TEST_QUERIES),
    dtype=torch.float32,
    device=DEVICE,
)

In [None]:
# Every record's id; `include=[]` asks for no documents, metadatas or embeddings.
all_docs = vectorstore._collection.get(include=[])

In [None]:
# Ids of all records in the collection, in the order the collection returned them.
all_ids = all_docs["ids"]

In [None]:
# Number of records fetched from Chroma per `get` call during the exact scan.
BATCH_SIZE = 10_000

In [None]:
def exact_cosine_ranking(query_embeds, collection, ids, batch_size, device):
    """Rank every scanned record of a Chroma collection by exact cosine similarity.

    This brute-force scan is the reference ("ground truth") that the approximate
    `collection.query` results are compared against.

    Args:
        query_embeds: 2-D tensor with one query embedding per row.
        collection: Chroma collection whose stored embeddings are scanned.
        ids: ids of the records to scan; requested `batch_size` at a time.
        batch_size: number of records requested per `collection.get` call.
        device: torch device on which the similarities are computed.

    Returns:
        A list with one entry per query row; each entry lists the scanned record
        ids ordered from most to least similar.
    """
    # Normalising both sides turns the dot product into cosine similarity; the
    # eps inside `normalize` avoids the 0/0 NaN a zero vector would produce.
    unit_queries = torch.nn.functional.normalize(query_embeds.to(device), dim=1)
    scanned_ids = []
    similarity_chunks = []
    # Read the collection once for all queries instead of once per query.
    for start in tqdm(range(0, len(ids), batch_size)):
        batch = collection.get(ids[start:start + batch_size], include=["embeddings"])
        # Track the ids Chroma actually returned, in the order it returned them,
        # so each similarity column stays paired with the right record.
        scanned_ids.extend(batch["ids"])
        # as_tensor accepts either a NumPy array or nested lists of floats.
        batch_embeds = torch.as_tensor(batch["embeddings"], dtype=unit_queries.dtype, device=device)
        unit_embeds = torch.nn.functional.normalize(batch_embeds, dim=1)
        similarity_chunks.append(unit_queries @ unit_embeds.T)
    if not similarity_chunks:
        return [[] for _ in range(unit_queries.shape[0])]
    # A single transfer to the host: iterating a CUDA tensor element by element
    # would synchronise with the device once per record id.
    order = torch.cat(similarity_chunks, dim=1).argsort(dim=1, descending=True).tolist()
    return [[scanned_ids[k] for k in row] for row in order]


ground_truth = exact_cosine_ranking(test_embeds, vectorstore._collection, all_ids, BATCH_SIZE, DEVICE)

In [None]:
# Recall@K cut-off: how many top neighbours are compared between the exact and approximate lists.
K = 5

In [None]:
# Approximate nearest neighbours from Chroma's index for the same query vectors
# used in the exact scan. Only embeddings are passed: Chroma rejects requests
# that supply both `query_texts` and `query_embeddings`, and embedding the raw
# texts would go through the collection's own embedding function, which is not
# necessarily the OpenAI model that produced `test_embeds`.
# NOTE(review): the index ranks by the collection's configured distance; that
# matches the cosine ranking only if the space is cosine or the embeddings are
# unit-norm — confirm the collection's `hnsw:space`.
queried = vectorstore._collection.query(query_embeddings=test_embeds.tolist(), n_results=K)["ids"]

In [None]:
# Recall@K per query: share of the exact top-K ids that the index also returned.
recalls = [
    len(set(ground_truth[q][:K]) & set(queried[q])) / K
    for q in range(len(TEST_QUERIES))
]

In [None]:
# Mean Recall@K over all test queries.
sum(recalls) / len(recalls)

In [None]:
# Per-query Recall@K, in the same order as TEST_QUERIES.
recalls