In [1]:
from dotenv import load_dotenv
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.vectorstores.utils import DistanceStrategy
from tqdm import tqdm 

import ir_datasets
import json
import openai
import os
import pickle

In [2]:
# Get Dataset

In [3]:
# Load the raw PubMedQA JSON. Entries are looked up by ID string below and
# expose "QUESTION" and "CONTEXTS" fields.
# The encoding is pinned to UTF-8 so decoding does not depend on the
# platform's default locale encoding.
with open("./dataset/pubmedqa/ori_pqal.json", "r", encoding="utf-8") as f:
    data = json.load(f)

In [4]:
# Sanity check: display the question text of one sample entry.
data["21645374"]["QUESTION"]

'Do mitochondria play a role in remodelling lace plant leaves during programmed cell death?'

In [5]:
# Sanity check: display the list of context passages for the same sample entry.
data["21645374"]["CONTEXTS"]

['Programmed cell death (PCD) is the regulated death of cells within an organism. The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD. The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature. The role of mitochondria during PCD has been recognized in animals; however, it has been less studied during PCD in plants.',
 'The following paper elucidates the role of mitochondrial dynamics during developmentally regulated PCD in vivo in A. madagascariensis. A single areole within a window stage leaf (PCD is occurring) was divided into three areas based on the progression of PCD; cells that will not undergo PCD (NPCD), cells in early stages of PCD (EPCD), and cells in late stages of PCD (LPCD). Window stage leaves were stained with the mitochondrial dye MitoTracker Red CMX

In [6]:
# One query per entry: entry ID -> {"text": <question string>}.
queries = {
    entry_id: {"text": entry["QUESTION"]}
    for entry_id, entry in data.items()
}

In [7]:
# Persist the query mapping next to the raw dataset.
with open("./dataset/pubmedqa/queries.pkl", "wb") as queries_out:
    pickle.dump(queries, queries_out)

In [8]:
# One document per entry: all context passages joined with newlines,
# stored as entry ID -> {"text": <joined contexts>}.
docs = {
    entry_id: {"text": "\n".join(entry["CONTEXTS"])}
    for entry_id, entry in data.items()
}

In [9]:
# Persist the document mapping next to the raw dataset.
with open("./dataset/pubmedqa/documents.pkl", "wb") as docs_out:
    pickle.dump(docs, docs_out)

In [10]:
# Relevance judgements: each query's only relevant document is the one that
# shares its ID, so every key maps to a single-element list holding itself.
rel_set = {entry_id: [entry_id] for entry_id in data}

In [11]:
# Persist the relevance mapping next to the raw dataset.
with open("./dataset/pubmedqa/relevance_set.pkl", "wb") as rel_out:
    pickle.dump(rel_set, rel_out)

In [12]:
# Get OpenAI Embeddings

In [13]:
# Load variables from a local .env file into the process environment.
# Presumably this is where OPENAI_API_KEY (read in the next cell) is defined - confirm.
load_dotenv()

True

In [14]:
# Build the OpenAI client; the key is taken from the environment so that it
# never appears in the notebook itself.
client = openai.OpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
)

In [15]:
#### API CALL WARNING ####

def get_embedding(text, model="text-embedding-ada-002"):
    """Request an embedding for ``text`` from the module-level OpenAI ``client``.

    Newlines in ``text`` are replaced by single spaces before the request is
    sent. Each call issues one embeddings API request.

    Args:
        text: The string to embed.
        model: Name of the embedding model passed to the API.

    Returns:
        The ``embedding`` attribute of the first item in the response data,
        or ``None`` (after printing a message) when the response is falsy,
        has no ``data`` attribute, or its ``data`` is empty.
    """
    flattened = " ".join(text.split("\n"))
    response = client.embeddings.create(input=[flattened], model=model)

    # Guard clause: anything without usable data is reported and skipped.
    items = getattr(response, "data", None) if response else None
    if not items:
        print("Invalid response or no embedding data received.")
        return None

    return items[0].embedding

In [16]:
# get_embedding(queries["21645374"]["text"])

In [17]:
# API CALL WARNING: one embeddings request per query.
# Each entry is replaced by a fresh dict holding the question text and its
# embedding; the set of keys is unchanged, so iterating while assigning is safe.
for query_id in tqdm(queries, desc="Generating Query Embeddings"):
    question = queries[query_id]["text"]
    queries[query_id] = {"text": question, "embedding": get_embedding(question)}

Generating Query Embeddings: 100%|███████████████████████| 1000/1000 [02:50<00:00,  5.86it/s]


In [18]:
# API CALL WARNING: one embeddings request per document.
# The embedding is added to each existing document record in place.
for doc_id, doc_record in tqdm(docs.items(), desc="Generating Documents Embeddings"):
    doc_record["embedding"] = get_embedding(doc_record["text"])

Generating Documents Embeddings: 100%|███████████████████| 1000/1000 [02:56<00:00,  5.67it/s]


In [19]:
# Pickle files for the embedded queries and documents; written by the cells
# below and read back when building the vector index.
query_file_path = "./openai_embeddings/pubmedqa/query_embeddings.pkl"
docs_file_path = "./openai_embeddings/pubmedqa/doc_embeddings.pkl"

In [21]:
# Make sure the output directory exists before writing, so the embeddings
# computed above are not lost to a FileNotFoundError on a fresh checkout.
query_out_dir = os.path.dirname(query_file_path)
if query_out_dir:
    os.makedirs(query_out_dir, exist_ok=True)

# Persist the queries together with their embeddings.
with open(query_file_path, "wb") as f:
    pickle.dump(queries, f)

In [22]:
# Make sure the output directory exists before writing, so the embeddings
# computed above are not lost to a FileNotFoundError on a fresh checkout.
docs_out_dir = os.path.dirname(docs_file_path)
if docs_out_dir:
    os.makedirs(docs_out_dir, exist_ok=True)

# Persist the documents together with their embeddings.
with open(docs_file_path, "wb") as f:
    pickle.dump(docs, f)

In [23]:
# Create VectorDB Index

In [25]:
# Read the document embeddings back from disk. The pickle is the file this
# notebook wrote itself; never unpickle files from an untrusted source.
with open(docs_file_path, "rb") as docs_in:
    loaded_docs = pickle.load(docs_in)

print("Document embeddings loaded successfully.")

Document embeddings loaded successfully.


In [26]:
# Build (document ID, embedding) pairs for the vector store.
# NOTE: this rebinds `data`, which earlier held the raw dataset dict, so the
# dataset-preparation cells cannot be re-run after this one without reloading.
data = [
    (doc_id, doc_record["embedding"])
    for doc_id, doc_record in loaded_docs.items()
]

In [27]:
# Build the FAISS index from the precomputed embeddings. The "text" stored for
# each vector is the document ID taken from the pairs built above.
# NOTE(review): confirm that the installed LangChain FAISS wrapper builds an
# inner-product index for DOT_PRODUCT; some versions only special-case
# MAX_INNER_PRODUCT and otherwise fall back to an L2 index.
faiss_vs = FAISS.from_embeddings(
    text_embeddings=data,
    embedding=OpenAIEmbeddings(),
    distance_strategy=DistanceStrategy.DOT_PRODUCT,
)

In [28]:
# Save the built vector store to a local directory for later use.
faiss_vs.save_local("./vectordb/faiss/pubmedqa/")