In [27]:
import os
import openai
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from dotenv import load_dotenv
from nltk.tokenize import word_tokenize
from tqdm import tqdm
import pickle
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from operator import itemgetter
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema.runnable import RunnablePassthrough
from langchain.vectorstores import FAISS
from langchain.vectorstores.utils import DistanceStrategy

In [2]:
load_dotenv()

# Uncomment to use the default DaVinci Model 
# llm = OpenAI(openai_api_key=os.getenv("OPENAI_API_KEY"))
# model = 'text-davinci-003'

# Uncomment to use the gpt-3.5-turbo-instruct model 
llm = OpenAI(model_name='gpt-3.5-turbo-instruct', openai_api_key = os.getenv("OPENAI_API_KEY"))
model = 'gpt-3.5-turbo-instruct'

In [9]:
import ir_datasets


dataset = ir_datasets.load("beir/nfcorpus/test")


queries = {}
for query in dataset.queries_iter():
    queries[query.query_id] = {"text":query.text}

docs = {}
count = 0
for doc in dataset.docs_iter():
    docs[doc.doc_id] = {"text": doc.text}
    count += 1

rel_set = {}
for qrel in dataset.qrels_iter():
    if qrel.query_id not in rel_set:
        rel_set[qrel.query_id] = []
    if qrel.relevance > 0: 
        rel_set[qrel.query_id].append(qrel.doc_id)


In [11]:
queries['PLAIN-2'], len(queries)

({'text': 'Do Cholesterol Statin Drugs Cause Breast Cancer?'}, 323)

In [10]:
docs['MED-10'], len(docs)

({'text': 'Recent studies have suggested that statins, an established drug group in the prevention of cardiovascular mortality, could delay or prevent breast cancer recurrence but the effect on disease-specific mortality remains unclear. We evaluated risk of breast cancer death among statin users in a population-based cohort of breast cancer patients. The study cohort included all newly diagnosed breast cancer patients in Finland during 1995–2003 (31,236 cases), identified from the Finnish Cancer Registry. Information on statin use before and after the diagnosis was obtained from a national prescription database. We used the Cox proportional hazards regression method to estimate mortality among statin users with statin use as time-dependent variable. A total of 4,151 participants had used statins. During the median follow-up of 3.25 years after the diagnosis (range 0.08–9.0 years) 6,011 participants died, of which 3,619 (60.2%) was due to breast cancer. After adjustment for age, tumor 

In [12]:
#### API CALL WARNING ####

client = openai.OpenAI(api_key = os.getenv("OPENAI_API_KEY"))

def get_embedding(text, model="text-embedding-ada-002"):
    text = text.replace("\n", " ")
    response = client.embeddings.create(input=[text], model=model)
    if response and hasattr(response, 'data') and response.data:
        embedding = response.data[0].embedding
        return embedding
    else:
        print("Invalid response or no embedding data received.")
        return None

In [13]:
for idx, query in tqdm(queries.items(), desc = 'Generating Embeddings'):
    query_text = query['text']
    queries[idx] = {'text': query_text, 'embedding': get_embedding(query_text)}

Generating Embeddings: 100%|███████████████████████████████| 323/323 [00:57<00:00,  5.57it/s]


In [14]:
for doc_id in tqdm(docs, desc = 'Generating Embeddings'):
    combined_text =  docs[doc_id]['text']
    docs[doc_id]['embedding'] = get_embedding(combined_text)

Generating Embeddings: 100%|█████████████████████████████| 3633/3633 [10:45<00:00,  5.62it/s]


In [15]:
docs_file_path = './backups/openai_embeddings/doc_embeddings_nfcorpus.pkl'
query_file_path = './backups/openai_embeddings/query_embeddings_nfcorpus.pkl'

In [16]:
with open(docs_file_path, 'wb') as file:
    pickle.dump(docs, file)
print(f"Embeddings saved to {docs_file_path}")

Embeddings saved to ./backups/openai_embeddings/doc_embeddings_nfcorpus.pkl


In [17]:
with open(query_file_path, 'wb') as file:
    pickle.dump(queries, file)
print(f"Embeddings saved to {query_file_path}")

Embeddings saved to ./backups/openai_embeddings/query_embeddings_nfcorpus.pkl


In [18]:
with open(docs_file_path, 'rb') as file:
    loaded_docs = pickle.load(file)
print("Document embeddings loaded successfully.")

Document embeddings loaded successfully.


In [19]:
with open(query_file_path, 'rb') as file:
    loaded_queries = pickle.load(file)

In [20]:
queries = loaded_queries
print("Query embeddings loaded successfully.")

Query embeddings loaded successfully.


In [24]:
loaded_docs["MED-10"].keys()

dict_keys(['text', 'embedding'])

In [25]:
annoy_data = []
for doc in loaded_docs:
    annoy_data.append((doc, loaded_docs[doc]["embedding"]))

In [28]:
faiss_vs = FAISS.from_embeddings(
    text_embeddings=annoy_data, 
    embedding=OpenAIEmbeddings(),
    distance_strategy=DistanceStrategy.DOT_PRODUCT)

In [29]:
faiss_vs.save_local("./backups/nfcorpus/faiss/")

In [30]:
# Load index from file
loaded_faiss_vs = FAISS.load_local(
    folder_path="./backups/nfcorpus/faiss/",
    embeddings=OpenAIEmbeddings())

retriever = loaded_faiss_vs.as_retriever(search_kwargs={'k': 10})

# Define the RAG pipeline
template = """
Answer the question or Explain the topic given this additional context: {context}
Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)

In [36]:
def format_docs(_docs):
    ls = []
    for doc in _docs:
        if doc.page_content in docs:
            ls.append(docs[doc.page_content]["text"][:400])
    return ls
    # return [docs[int(doc.page_content)]["body"] for doc in _docs]

In [37]:
chain = ({"context": retriever | format_docs, "question": RunnablePassthrough()} 
         | prompt 
         | llm 
         | StrOutputParser())

In [34]:
queries['PLAIN-2']['text']

'Do Cholesterol Statin Drugs Cause Breast Cancer?'

In [33]:
chain.invoke(queries['PLAIN-2']['text'])

'\nAnswer: The available evidence on the association between statin use and breast cancer risk is conflicting. Some studies suggest that statins may decrease the risk of breast cancer, while others have found no significant effect. A recent meta-analysis of observational studies did not find any evidence to support the hypothesis that statins have a protective effect against breast cancer. However, there have been some studies that have found an increased risk of breast cancer in long-term users of statins. Further research, including randomized clinical trials, is needed to confirm this association and understand the underlying biological mechanisms. Additionally, there is evidence that low cholesterol levels may be associated with a lower risk of breast cancer, but more research is needed to understand this relationship. Ultimately, the role of cholesterol statin drugs in breast cancer risk remains unclear and more studies are needed to fully understand their impact.'

In [38]:
#### API CALL WARNING #####

rag_responses = {}
loq = []
count  = 0
# Run RAG pipeline for every question
for query_id in tqdm(rel_set.keys(), desc = 'Asking Queries to ChatGPT with RAG'):
    # if count > 100:
    #     break
    query_text = queries[query_id]['text']
    response = chain.invoke(query_text)
    rag_responses[query_id] = response
    count+=1

Asking Queries to ChatGPT with RAG: 100%|██████████████████| 323/323 [09:12<00:00,  1.71s/it]


In [39]:
#### DUMP OVERWRITE WARNING ####

rag_responses_file_path = './backups/openai_with_rag_responses_nf_corpus_' + model + '.pkl'
with open(rag_responses_file_path, 'wb') as file:
    pickle.dump(rag_responses, file)

print(f"RAG responses saved to {rag_responses_file_path}")

RAG responses saved to ./backups/openai_with_rag_responses_nf_corpus_gpt-3.5-turbo-instruct.pkl


In [58]:
openai_with_rag_responses_file_path = './backups/openai_with_rag_responses_nf_corpus_' + model + '.pkl'
with open(openai_with_rag_responses_file_path, 'rb') as file:
    rag_responses = pickle.load(file)

In [40]:
 # Sanity
rag_responses['PLAIN-2']

'\nAnswer: Based on the emerging evidence and conflicting results from previous studies, it is unclear if cholesterol statin drugs directly cause breast cancer. Some studies suggest that statins may decrease the risk of breast cancer, while others have shown no significant association. However, further research is needed to fully understand the potential effects of statins on breast cancer risk. '

In [41]:
# Implement BLEU evaluation function
def compute_bleu(references, candidate):
    smoothing = SmoothingFunction().method5
    return sentence_bleu(references, candidate, smoothing_function=smoothing)

# Implement ROUGE evaluation function
def compute_rouge(references, candidate):
    scorer = rouge_scorer.RougeScorer(['rouge1'], use_stemmer=True)
    total_score = 0

    # Compute ROUGE for each reference
    for reference in references:
        scores = scorer.score(reference, candidate)
        total_score += scores['rouge1'].fmeasure

    # Calculate average score
    average_score = total_score / len(references)
    return average_score

# Evaluate BLEU and ROUGE for each query

K = 15 # Number of most relevant docs to consider for scoring performance
total_bleu_score = 0.0
total_rouge_score = 0.0
num_queries = 0

for query_id, relevant_docs in rel_set.items():
    query_text = queries[query_id]['text']
    response = rag_responses[query_id]

    # print(query_id, "\n\n", query_text, "\n\nResponse:\n", response, "\nTopmost relevant Doc:\n", docs[relevant_docs[0]]['text'], "\n======\n")
    
    # Evaluate using BLEU
    bleu_score = compute_bleu([docs[id]['text'] for id in relevant_docs[:K]], response)
    total_bleu_score += bleu_score

    # Evaluate using ROUGE
    rouge_score = compute_rouge([docs[id]['text'] for id in relevant_docs[:K]], response)
    total_rouge_score += rouge_score

    num_queries += 1
    if num_queries == 101:
        break

# Calculate mean scores
mean_bleu_score = total_bleu_score / num_queries
mean_rouge_score = total_rouge_score / num_queries

print(f"Mean BLEU Score: {mean_bleu_score:.4f}")
print(f"Mean ROUGE Score: {mean_rouge_score:.4f}")

Mean BLEU Score: 0.6099
Mean ROUGE Score: 0.2024


In [18]:
# da-vinci
# Mean BLEU Score: 0.8224
# Mean ROUGE Score: 0.2105

# Mean BLEU Score: 0.8377
# Mean ROUGE Score: 0.2226

# gpt-3.5-turbo-instruct

# Mean BLEU Score: 0.7869
# Mean ROUGE Score: 0.2407

# Mean BLEU Score: 0.7937
# Mean ROUGE Score: 0.2425
